add autodl
12
AutoDL-Projects/xautodl/__init__.py
Normal file
@@ -0,0 +1,12 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
#####################################################
# An Automated Deep Learning Package to support     #
# research activities.                              #
#####################################################


def version():
    versions = ["0.9.9"]  # 2021.06.01
    versions = ["1.0.0"]  # 2021.08.14
    return versions[-1]
20
AutoDL-Projects/xautodl/config_utils/__init__.py
Normal file
@@ -0,0 +1,20 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# general config related functions
from .config_utils import load_config, dict2config, configure2str

# the args setting for different experiments
from .basic_args import obtain_basic_args
from .attention_args import obtain_attention_args
from .random_baseline import obtain_RandomSearch_args
from .cls_kd_args import obtain_cls_kd_args
from .cls_init_args import obtain_cls_init_args
from .search_single_args import obtain_search_single_args
from .search_args import obtain_search_args

# for network pruning
from .pruning_args import obtain_pruning_args

# utils for args
from .args_utils import arg_str2bool
12
AutoDL-Projects/xautodl/config_utils/args_utils.py
Normal file
@@ -0,0 +1,12 @@
import argparse


def arg_str2bool(v):
    if isinstance(v, bool):
        return v
    elif v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")
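Usage sketch (editorial note, not part of the commit): arg_str2bool is meant to be passed as the type of a boolean flag so that strings such as "yes" or "0" parse cleanly; the flag name below is illustrative.

    import argparse
    from xautodl.config_utils import arg_str2bool

    parser = argparse.ArgumentParser()
    parser.add_argument("--use_cutout", type=arg_str2bool, default=False)  # hypothetical flag
    args = parser.parse_args(["--use_cutout", "yes"])
    assert args.use_cutout is True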
32
AutoDL-Projects/xautodl/config_utils/attention_args.py
Normal file
@@ -0,0 +1,32 @@
import random, argparse
from .share_args import add_shared_args


def obtain_attention_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument(
        "--model_config", type=str, help="The path to the model configuration"
    )
    parser.add_argument(
        "--optim_config", type=str, help="The path to the optimizer configuration"
    )
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument("--att_channel", type=int, help=".")
    parser.add_argument("--att_spatial", type=str, help=".")
    parser.add_argument("--att_active", type=str, help=".")
    add_shared_args(parser)
    # Optimization options
    parser.add_argument(
        "--batch_size", type=int, default=2, help="Batch size for training."
    )
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    return args
44
AutoDL-Projects/xautodl/config_utils/basic_args.py
Normal file
@@ -0,0 +1,44 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##################################################
import random, argparse
from .share_args import add_shared_args


def obtain_basic_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument(
        "--model_config", type=str, help="The path to the model configuration"
    )
    parser.add_argument(
        "--optim_config", type=str, help="The path to the optimizer configuration"
    )
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument(
        "--model_source",
        type=str,
        default="normal",
        help="The source of the model definition.",
    )
    parser.add_argument(
        "--extra_model_path",
        type=str,
        default=None,
        help="The extra model ckp file (help to indicate the searched architecture).",
    )
    add_shared_args(parser)
    # Optimization options
    parser.add_argument(
        "--batch_size", type=int, default=2, help="Batch size for training."
    )
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    return args
32
AutoDL-Projects/xautodl/config_utils/cls_init_args.py
Normal file
@@ -0,0 +1,32 @@
import random, argparse
from .share_args import add_shared_args


def obtain_cls_init_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument(
        "--model_config", type=str, help="The path to the model configuration"
    )
    parser.add_argument(
        "--optim_config", type=str, help="The path to the optimizer configuration"
    )
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument(
        "--init_checkpoint", type=str, help="The checkpoint path to the initial model."
    )
    add_shared_args(parser)
    # Optimization options
    parser.add_argument(
        "--batch_size", type=int, default=2, help="Batch size for training."
    )
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    return args
43
AutoDL-Projects/xautodl/config_utils/cls_kd_args.py
Normal file
@@ -0,0 +1,43 @@
import random, argparse
from .share_args import add_shared_args


def obtain_cls_kd_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument(
        "--model_config", type=str, help="The path to the model configuration"
    )
    parser.add_argument(
        "--optim_config", type=str, help="The path to the optimizer configuration"
    )
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument(
        "--KD_checkpoint",
        type=str,
        help="The teacher checkpoint in knowledge distillation.",
    )
    parser.add_argument(
        "--KD_alpha", type=float, help="The alpha parameter in knowledge distillation."
    )
    parser.add_argument(
        "--KD_temperature",
        type=float,
        help="The temperature parameter in knowledge distillation.",
    )
    # parser.add_argument('--KD_feature', type=float, help='Knowledge distillation at the feature level.')
    add_shared_args(parser)
    # Optimization options
    parser.add_argument(
        "--batch_size", type=int, default=2, help="Batch size for training."
    )
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    return args
135
AutoDL-Projects/xautodl/config_utils/config_utils.py
Normal file
@@ -0,0 +1,135 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os, json
from os import path as osp
from pathlib import Path
from collections import namedtuple

support_types = ("str", "int", "bool", "float", "none")


def convert_param(original_lists):
    assert isinstance(original_lists, list), "The type is not right : {:}".format(
        original_lists
    )
    ctype, value = original_lists[0], original_lists[1]
    assert ctype in support_types, "Ctype={:}, support={:}".format(ctype, support_types)
    is_list = isinstance(value, list)
    if not is_list:
        value = [value]
    outs = []
    for x in value:
        if ctype == "int":
            x = int(x)
        elif ctype == "str":
            x = str(x)
        elif ctype == "bool":
            x = bool(int(x))
        elif ctype == "float":
            x = float(x)
        elif ctype == "none":
            if x.lower() != "none":
                raise ValueError(
                    "For the none type, the value must be none instead of {:}".format(x)
                )
            x = None
        else:
            raise TypeError("Does not know this type : {:}".format(ctype))
        outs.append(x)
    if not is_list:
        outs = outs[0]
    return outs


def load_config(path, extra, logger):
    path = str(path)
    if hasattr(logger, "log"):
        logger.log(path)
    assert os.path.exists(path), "Can not find {:}".format(path)
    # Reading data back
    with open(path, "r") as f:
        data = json.load(f)
    content = {k: convert_param(v) for k, v in data.items()}
    assert extra is None or isinstance(
        extra, dict
    ), "invalid type of extra : {:}".format(extra)
    if isinstance(extra, dict):
        content = {**content, **extra}
    Arguments = namedtuple("Configure", " ".join(content.keys()))
    content = Arguments(**content)
    if hasattr(logger, "log"):
        logger.log("{:}".format(content))
    return content


def configure2str(config, xpath=None):
    if not isinstance(config, dict):
        config = config._asdict()

    def cstring(x):
        return '"{:}"'.format(x)

    def gtype(x):
        if isinstance(x, list):
            x = x[0]
        if isinstance(x, str):
            return "str"
        elif isinstance(x, bool):
            return "bool"
        elif isinstance(x, int):
            return "int"
        elif isinstance(x, float):
            return "float"
        elif x is None:
            return "none"
        else:
            raise ValueError("invalid : {:}".format(x))

    def cvalue(x, xtype):
        if isinstance(x, list):
            is_list = True
        else:
            is_list, x = False, [x]
        temps = []
        for temp in x:
            if xtype == "bool":
                temp = cstring(int(temp))
            elif xtype == "none":
                temp = cstring("None")
            else:
                temp = cstring(temp)
            temps.append(temp)
        if is_list:
            return "[{:}]".format(", ".join(temps))
        else:
            return temps[0]

    xstrings = []
    for key, value in config.items():
        xtype = gtype(value)
        string = "  {:20s} : [{:8s}, {:}]".format(
            cstring(key), cstring(xtype), cvalue(value, xtype)
        )
        xstrings.append(string)
    Fstring = "{\n" + ",\n".join(xstrings) + "\n}"
    if xpath is not None:
        parent = Path(xpath).resolve().parent
        parent.mkdir(parents=True, exist_ok=True)
        if osp.isfile(xpath):
            os.remove(xpath)
        with open(xpath, "w") as text_file:
            text_file.write("{:}".format(Fstring))
    return Fstring


def dict2config(xdict, logger):
    assert isinstance(xdict, dict), "invalid type : {:}".format(type(xdict))
    Arguments = namedtuple("Configure", " ".join(xdict.keys()))
    content = Arguments(**xdict)
    if hasattr(logger, "log"):
        logger.log("{:}".format(content))
    return content
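A minimal sketch (editorial, not part of the commit) of the JSON layout these helpers expect: each key maps to a [type, value] pair whose type is one of support_types, and load_config returns a namedtuple; the file name and values below are made up for illustration.

    import json
    from xautodl.config_utils import load_config

    example = {"batch_size": ["int", 256], "lr": ["float", "0.1"], "scheduler": ["str", "cos"]}
    with open("example-config.json", "w") as f:  # hypothetical path
        json.dump(example, f)
    config = load_config("example-config.json", None, None)  # no extra dict, no logger
    print(config.batch_size, config.lr, config.scheduler)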
48
AutoDL-Projects/xautodl/config_utils/pruning_args.py
Normal file
@@ -0,0 +1,48 @@
import os, sys, time, random, argparse
from .share_args import add_shared_args


def obtain_pruning_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument(
        "--model_config", type=str, help="The path to the model configuration"
    )
    parser.add_argument(
        "--optim_config", type=str, help="The path to the optimizer configuration"
    )
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument(
        "--keep_ratio",
        type=float,
        help="The ratio of remaining channels compared to the original network.",
    )
    parser.add_argument("--model_version", type=str, help="The network version.")
    parser.add_argument(
        "--KD_alpha", type=float, help="The alpha parameter in knowledge distillation."
    )
    parser.add_argument(
        "--KD_temperature",
        type=float,
        help="The temperature parameter in knowledge distillation.",
    )
    parser.add_argument("--Regular_W_feat", type=float, help="The .")
    parser.add_argument("--Regular_W_conv", type=float, help="The .")
    add_shared_args(parser)
    # Optimization options
    parser.add_argument(
        "--batch_size", type=int, default=2, help="Batch size for training."
    )
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    assert (
        args.keep_ratio > 0 and args.keep_ratio <= 1
    ), "invalid keep ratio : {:}".format(args.keep_ratio)
    return args
44
AutoDL-Projects/xautodl/config_utils/random_baseline.py
Normal file
@@ -0,0 +1,44 @@
import os, sys, time, random, argparse
from .share_args import add_shared_args


def obtain_RandomSearch_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument("--init_model", type=str, help="The initialization model path.")
    parser.add_argument(
        "--expect_flop", type=float, help="The expected flop keep ratio."
    )
    parser.add_argument(
        "--arch_nums",
        type=int,
        help="The maximum number of random architectures to generate.",
    )
    parser.add_argument(
        "--model_config", type=str, help="The path to the model configuration"
    )
    parser.add_argument(
        "--optim_config", type=str, help="The path to the optimizer configuration"
    )
    parser.add_argument(
        "--random_mode",
        type=str,
        choices=["random", "fix"],
        help="The sampling mode for the random baseline.",
    )
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    add_shared_args(parser)
    # Optimization options
    parser.add_argument(
        "--batch_size", type=int, default=2, help="Batch size for training."
    )
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    # assert args.flop_ratio_min < args.flop_ratio_max, 'flop-ratio {:} vs {:}'.format(args.flop_ratio_min, args.flop_ratio_max)
    return args
53
AutoDL-Projects/xautodl/config_utils/search_args.py
Normal file
@@ -0,0 +1,53 @@
import os, sys, time, random, argparse
from .share_args import add_shared_args


def obtain_search_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument(
        "--model_config", type=str, help="The path to the model configuration"
    )
    parser.add_argument(
        "--optim_config", type=str, help="The path to the optimizer configuration"
    )
    parser.add_argument("--split_path", type=str, help="The split file path.")
    # parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.')
    parser.add_argument(
        "--gumbel_tau_max", type=float, help="The maximum tau for Gumbel."
    )
    parser.add_argument(
        "--gumbel_tau_min", type=float, help="The minimum tau for Gumbel."
    )
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.")
    parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.")
    parser.add_argument(
        "--FLOP_tolerant", type=float, help="The tolerant range for FLOP."
    )
    # ablation studies
    parser.add_argument(
        "--ablation_num_select",
        type=int,
        help="The number of randomly selected channels.",
    )
    add_shared_args(parser)
    # Optimization options
    parser.add_argument(
        "--batch_size", type=int, default=2, help="Batch size for training."
    )
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
    assert (
        args.FLOP_tolerant is not None and args.FLOP_tolerant > 0
    ), "invalid FLOP_tolerant : {:}".format(args.FLOP_tolerant)
    # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
    # args.arch_para_pure = bool(args.arch_para_pure)
    return args
48
AutoDL-Projects/xautodl/config_utils/search_single_args.py
Normal file
@@ -0,0 +1,48 @@
import os, sys, time, random, argparse
from .share_args import add_shared_args


def obtain_search_single_args():
    parser = argparse.ArgumentParser(
        description="Train a classification model on typical image classification datasets.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--resume", type=str, help="Resume path.")
    parser.add_argument(
        "--model_config", type=str, help="The path to the model configuration"
    )
    parser.add_argument(
        "--optim_config", type=str, help="The path to the optimizer configuration"
    )
    parser.add_argument("--split_path", type=str, help="The split file path.")
    parser.add_argument("--search_shape", type=str, help="The shape to be searched.")
    # parser.add_argument('--arch_para_pure', type=int, help='The architecture-parameter pure or not.')
    parser.add_argument(
        "--gumbel_tau_max", type=float, help="The maximum tau for Gumbel."
    )
    parser.add_argument(
        "--gumbel_tau_min", type=float, help="The minimum tau for Gumbel."
    )
    parser.add_argument("--procedure", type=str, help="The procedure basic prefix.")
    parser.add_argument("--FLOP_ratio", type=float, help="The expected FLOP ratio.")
    parser.add_argument("--FLOP_weight", type=float, help="The loss weight for FLOP.")
    parser.add_argument(
        "--FLOP_tolerant", type=float, help="The tolerant range for FLOP."
    )
    add_shared_args(parser)
    # Optimization options
    parser.add_argument(
        "--batch_size", type=int, default=2, help="Batch size for training."
    )
    args = parser.parse_args()

    if args.rand_seed is None or args.rand_seed < 0:
        args.rand_seed = random.randint(1, 100000)
    assert args.save_dir is not None, "save-path argument can not be None"
    assert args.gumbel_tau_max is not None and args.gumbel_tau_min is not None
    assert (
        args.FLOP_tolerant is not None and args.FLOP_tolerant > 0
    ), "invalid FLOP_tolerant : {:}".format(args.FLOP_tolerant)
    # assert args.arch_para_pure is not None, 'arch_para_pure is not None: {:}'.format(args.arch_para_pure)
    # args.arch_para_pure = bool(args.arch_para_pure)
    return args
39
AutoDL-Projects/xautodl/config_utils/share_args.py
Normal file
@@ -0,0 +1,39 @@
import os, sys, time, random, argparse


def add_shared_args(parser):
    # Data Generation
    parser.add_argument("--dataset", type=str, help="The dataset name.")
    parser.add_argument("--data_path", type=str, help="The path to the dataset.")
    parser.add_argument(
        "--cutout_length", type=int, help="The cutout length; a negative value disables cutout."
    )
    # Printing
    parser.add_argument(
        "--print_freq", type=int, default=100, help="print frequency (default: 100)"
    )
    parser.add_argument(
        "--print_freq_eval",
        type=int,
        default=100,
        help="print frequency (default: 100)",
    )
    # Checkpoints
    parser.add_argument(
        "--eval_frequency",
        type=int,
        default=1,
        help="evaluation frequency (default: 1)",
    )
    parser.add_argument(
        "--save_dir", type=str, help="Folder to save checkpoints and log."
    )
    # Acceleration
    parser.add_argument(
        "--workers",
        type=int,
        default=8,
        help="number of data loading workers (default: 8)",
    )
    # Random Seed
    parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
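A brief sketch (editorial, not part of the commit) of how the per-experiment helpers compose with add_shared_args: each obtain_*_args() builds its own parser and then appends the shared flags (--dataset, --data_path, --save_dir, --rand_seed, ...); the command-line values below are illustrative.

    import sys
    from xautodl.config_utils import obtain_basic_args

    sys.argv = ["train.py", "--dataset", "cifar10", "--data_path", "./data",
                "--save_dir", "./output/demo"]  # hypothetical CLI
    args = obtain_basic_args()
    print(args.dataset, args.save_dir, args.rand_seed)  # rand_seed is randomized when < 0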
16
AutoDL-Projects/xautodl/log_utils/__init__.py
Normal file
@@ -0,0 +1,16 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# None of the modules in this package rely on PyTorch or TensorFlow.
# The dependencies are: os, sys, time, numpy, and (possibly) matplotlib.
##################################################
from .logger import Logger, PrintLogger
from .meter import AverageMeter
from .time_utils import (
    time_for_file,
    time_string,
    time_string_short,
    time_print,
    convert_secs2time,
)
from .pickle_wrap import pickle_save, pickle_load
173
AutoDL-Projects/xautodl/log_utils/logger.py
Normal file
@@ -0,0 +1,173 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from pathlib import Path
import importlib, warnings
import os, sys, time, numpy as np

if sys.version_info.major == 2:  # Python 2.x
    from StringIO import StringIO as BIO
else:  # Python 3.x
    from io import BytesIO as BIO

if importlib.util.find_spec("tensorflow"):
    import tensorflow as tf


class PrintLogger(object):
    def __init__(self):
        """Create a summary writer logging to log_dir."""
        self.name = "PrintLogger"

    def log(self, string):
        print(string)

    def close(self):
        print("-" * 30 + " close printer " + "-" * 30)


class Logger(object):
    def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False):
        """Create a summary writer logging to log_dir."""
        self.seed = int(seed)
        self.log_dir = Path(log_dir)
        self.model_dir = Path(log_dir) / "checkpoint"
        self.log_dir.mkdir(parents=True, exist_ok=True)
        if create_model_dir:
            self.model_dir.mkdir(parents=True, exist_ok=True)
        # self.meta_dir.mkdir(mode=0o775, parents=True, exist_ok=True)

        self.use_tf = bool(use_tf)
        self.tensorboard_dir = self.log_dir / (
            "tensorboard-{:}".format(time.strftime("%d-%h", time.gmtime(time.time())))
        )
        # self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h-at-%H:%M:%S', time.gmtime(time.time()) )))
        self.logger_path = self.log_dir / "seed-{:}-T-{:}.log".format(
            self.seed, time.strftime("%d-%h-at-%H-%M-%S", time.gmtime(time.time()))
        )
        self.logger_file = open(self.logger_path, "w")

        if self.use_tf:
            self.tensorboard_dir.mkdir(mode=0o775, parents=True, exist_ok=True)
            self.writer = tf.summary.FileWriter(str(self.tensorboard_dir))
        else:
            self.writer = None

    def __repr__(self):
        return "{name}(dir={log_dir}, use-tf={use_tf}, writer={writer})".format(
            name=self.__class__.__name__, **self.__dict__
        )

    def path(self, mode):
        valids = ("model", "best", "info", "log", None)
        if mode is None:
            return self.log_dir
        elif mode == "model":
            return self.model_dir / "seed-{:}-basic.pth".format(self.seed)
        elif mode == "best":
            return self.model_dir / "seed-{:}-best.pth".format(self.seed)
        elif mode == "info":
            return self.log_dir / "seed-{:}-last-info.pth".format(self.seed)
        elif mode == "log":
            return self.log_dir
        else:
            raise TypeError("Unknown mode = {:}, valid modes = {:}".format(mode, valids))

    def extract_log(self):
        return self.logger_file

    def close(self):
        self.logger_file.close()
        if self.writer is not None:
            self.writer.close()

    def log(self, string, save=True, stdout=False):
        if stdout:
            sys.stdout.write(string)
            sys.stdout.flush()
        else:
            print(string)
        if save:
            self.logger_file.write("{:}\n".format(string))
            self.logger_file.flush()

    def scalar_summary(self, tags, values, step):
        """Log a scalar variable."""
        if not self.use_tf:
            warnings.warn("scalar_summary was called, but use_tf is False (TensorBoard logging is disabled)")
        else:
            assert isinstance(tags, list) == isinstance(
                values, list
            ), "Type : {:} vs {:}".format(type(tags), type(values))
            if not isinstance(tags, list):
                tags, values = [tags], [values]
            for tag, value in zip(tags, values):
                summary = tf.Summary(
                    value=[tf.Summary.Value(tag=tag, simple_value=value)]
                )
                self.writer.add_summary(summary, step)
                self.writer.flush()

    def image_summary(self, tag, images, step):
        """Log a list of images."""
        import scipy

        if not self.use_tf:
            warnings.warn("image_summary was called, but use_tf is False (TensorBoard logging is disabled)")
            return

        img_summaries = []
        for i, img in enumerate(images):
            # Write the image to an in-memory buffer
            # (BIO is StringIO on Python 2 and BytesIO on Python 3, imported above)
            s = BIO()
            scipy.misc.toimage(img).save(s, format="png")

            # Create an Image object
            img_sum = tf.Summary.Image(
                encoded_image_string=s.getvalue(),
                height=img.shape[0],
                width=img.shape[1],
            )
            # Create a Summary value
            img_summaries.append(
                tf.Summary.Value(tag="{}/{}".format(tag, i), image=img_sum)
            )

        # Create and write Summary
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)
        self.writer.flush()

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""
        if not self.use_tf:
            raise ValueError("Do not have tensorflow")
        import tensorflow as tf

        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)

        # Fill the fields of the histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values**2))

        # Drop the start of the first bin
        bin_edges = bin_edges[1:]

        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)

        # Create and write Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()
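Usage sketch (editorial, not part of the commit): a Logger writes every message to a seed-<seed>-T-<time>.log file under log_dir and exposes checkpoint paths via path(); the directory name below is illustrative.

    from xautodl.log_utils import Logger

    logger = Logger("./output/demo", seed=1)  # hypothetical directory
    logger.log("epoch 0 done")
    ckpt_path = logger.path("model")  # .../checkpoint/seed-1-basic.pth
    logger.close()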
120
AutoDL-Projects/xautodl/log_utils/meter.py
Normal file
@@ -0,0 +1,120 @@
import numpy as np


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __repr__(self):
        return "{name}(val={val}, avg={avg}, count={count})".format(
            name=self.__class__.__name__, **self.__dict__
        )


class RecorderMeter(object):
    """Computes and stores the minimum loss value and its epoch index"""

    def __init__(self, total_epoch):
        self.reset(total_epoch)

    def reset(self, total_epoch):
        assert total_epoch > 0, "total_epoch should be greater than 0 vs {:}".format(
            total_epoch
        )
        self.total_epoch = total_epoch
        self.current_epoch = 0
        self.epoch_losses = np.zeros(
            (self.total_epoch, 2), dtype=np.float32
        )  # [epoch, train/val]
        self.epoch_losses = self.epoch_losses - 1
        self.epoch_accuracy = np.zeros(
            (self.total_epoch, 2), dtype=np.float32
        )  # [epoch, train/val]
        self.epoch_accuracy = self.epoch_accuracy

    def update(self, idx, train_loss, train_acc, val_loss, val_acc):
        assert (
            idx >= 0 and idx < self.total_epoch
        ), "total_epoch : {} , but update with the {} index".format(
            self.total_epoch, idx
        )
        self.epoch_losses[idx, 0] = train_loss
        self.epoch_losses[idx, 1] = val_loss
        self.epoch_accuracy[idx, 0] = train_acc
        self.epoch_accuracy[idx, 1] = val_acc
        self.current_epoch = idx + 1
        return self.max_accuracy(False) == self.epoch_accuracy[idx, 1]

    def max_accuracy(self, istrain):
        if self.current_epoch <= 0:
            return 0
        if istrain:
            return self.epoch_accuracy[: self.current_epoch, 0].max()
        else:
            return self.epoch_accuracy[: self.current_epoch, 1].max()

    def plot_curve(self, save_path):
        import matplotlib

        matplotlib.use("agg")
        import matplotlib.pyplot as plt

        title = "the accuracy/loss curve of train/val"
        dpi = 100
        width, height = 1600, 1000
        legend_fontsize = 10
        figsize = width / float(dpi), height / float(dpi)

        fig = plt.figure(figsize=figsize)
        x_axis = np.array([i for i in range(self.total_epoch)])  # epochs
        y_axis = np.zeros(self.total_epoch)

        plt.xlim(0, self.total_epoch)
        plt.ylim(0, 100)
        interval_y = 5
        interval_x = 5
        plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
        plt.yticks(np.arange(0, 100 + interval_y, interval_y))
        plt.grid()
        plt.title(title, fontsize=20)
        plt.xlabel("the training epoch", fontsize=16)
        plt.ylabel("accuracy", fontsize=16)

        y_axis[:] = self.epoch_accuracy[:, 0]
        plt.plot(x_axis, y_axis, color="g", linestyle="-", label="train-accuracy", lw=2)
        plt.legend(loc=4, fontsize=legend_fontsize)

        y_axis[:] = self.epoch_accuracy[:, 1]
        plt.plot(x_axis, y_axis, color="y", linestyle="-", label="valid-accuracy", lw=2)
        plt.legend(loc=4, fontsize=legend_fontsize)

        y_axis[:] = self.epoch_losses[:, 0]
        plt.plot(
            x_axis, y_axis * 50, color="g", linestyle=":", label="train-loss-x50", lw=2
        )
        plt.legend(loc=4, fontsize=legend_fontsize)

        y_axis[:] = self.epoch_losses[:, 1]
        plt.plot(
            x_axis, y_axis * 50, color="y", linestyle=":", label="valid-loss-x50", lw=2
        )
        plt.legend(loc=4, fontsize=legend_fontsize)

        if save_path is not None:
            fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
            print("---- save figure {} into {}".format(title, save_path))
        plt.close(fig)
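Usage sketch (editorial, not part of the commit): AverageMeter keeps a running, count-weighted average, which is how per-batch losses are typically aggregated over an epoch; the numbers below are illustrative.

    from xautodl.log_utils import AverageMeter

    losses = AverageMeter()
    losses.update(0.9, n=64)  # batch loss 0.9 averaged over 64 samples
    losses.update(0.7, n=64)
    print(losses.val, losses.avg)  # 0.7, 0.8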
21
AutoDL-Projects/xautodl/log_utils/pickle_wrap.py
Normal file
@@ -0,0 +1,21 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import pickle
from pathlib import Path


def pickle_save(obj, path):
    file_path = Path(path)
    file_dir = file_path.parent
    file_dir.mkdir(parents=True, exist_ok=True)
    with file_path.open("wb") as f:
        pickle.dump(obj, f)


def pickle_load(path):
    if not Path(path).exists():
        raise ValueError("{:} does not exist".format(path))
    with Path(path).open("rb") as f:
        data = pickle.load(f)
    return data
49
AutoDL-Projects/xautodl/log_utils/time_utils.py
Normal file
@@ -0,0 +1,49 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import time, sys
import numpy as np


def time_for_file():
    ISOTIMEFORMAT = "%d-%h-at-%H-%M-%S"
    return "{:}".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))


def time_string():
    ISOTIMEFORMAT = "%Y-%m-%d %X"
    string = "[{:}]".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
    return string


def time_string_short():
    ISOTIMEFORMAT = "%Y%m%d"
    string = "{:}".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
    return string


def time_print(string, is_print=True):
    if is_print:
        print("{} : {}".format(time_string(), string))


def convert_secs2time(epoch_time, return_str=False):
    need_hour = int(epoch_time / 3600)
    need_mins = int((epoch_time - 3600 * need_hour) / 60)
    need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins)
    if return_str:
        str = "[{:02d}:{:02d}:{:02d}]".format(need_hour, need_mins, need_secs)
        return str
    else:
        return need_hour, need_mins, need_secs


def print_log(print_string, log):
    # if isinstance(log, Logger): log.log('{:}'.format(print_string))
    if hasattr(log, "log"):
        log.log("{:}".format(print_string))
    else:
        print("{:}".format(print_string))
        if log is not None:
            log.write("{:}\n".format(print_string))
            log.flush()
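Usage sketch (editorial, not part of the commit): convert_secs2time splits a duration in seconds into (hours, minutes, seconds) or an [HH:MM:SS] string, which is handy for ETA-style log messages; the value below is illustrative.

    from xautodl.log_utils import convert_secs2time

    print(convert_secs2time(3725))                    # (1, 2, 5)
    print(convert_secs2time(3725, return_str=True))   # [01:02:05]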
117
AutoDL-Projects/xautodl/models/CifarDenseNet.py
Normal file
@@ -0,0 +1,117 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
import torch.nn as nn
import torch.nn.functional as F
from .initialization import initialize_resnet


class Bottleneck(nn.Module):
    def __init__(self, nChannels, growthRate):
        super(Bottleneck, self).__init__()
        interChannels = 4 * growthRate
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(interChannels)
        self.conv2 = nn.Conv2d(
            interChannels, growthRate, kernel_size=3, padding=1, bias=False
        )

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        out = torch.cat((x, out), 1)
        return out


class SingleLayer(nn.Module):
    def __init__(self, nChannels, growthRate):
        super(SingleLayer, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(
            nChannels, growthRate, kernel_size=3, padding=1, bias=False
        )

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = torch.cat((x, out), 1)
        return out


class Transition(nn.Module):
    def __init__(self, nChannels, nOutChannels):
        super(Transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = F.avg_pool2d(out, 2)
        return out


class DenseNet(nn.Module):
    def __init__(self, growthRate, depth, reduction, nClasses, bottleneck):
        super(DenseNet, self).__init__()

        if bottleneck:
            nDenseBlocks = int((depth - 4) / 6)
        else:
            nDenseBlocks = int((depth - 4) / 3)

        self.message = "CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}".format(
            "bottleneck" if bottleneck else "basic",
            depth,
            reduction,
            growthRate,
            nClasses,
        )

        nChannels = 2 * growthRate
        self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False)

        self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans1 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans2 = Transition(nChannels, nOutChannels)

        nChannels = nOutChannels
        self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks * growthRate

        self.act = nn.Sequential(
            nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True), nn.AvgPool2d(8)
        )
        self.fc = nn.Linear(nChannels, nClasses)

        self.apply(initialize_resnet)

    def get_message(self):
        return self.message

    def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
        layers = []
        for i in range(int(nDenseBlocks)):
            if bottleneck:
                layers.append(Bottleneck(nChannels, growthRate))
            else:
                layers.append(SingleLayer(nChannels, growthRate))
            nChannels += growthRate
        return nn.Sequential(*layers)

    def forward(self, inputs):
        out = self.conv1(inputs)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.dense3(out)
        features = self.act(out)
        features = features.view(features.size(0), -1)
        out = self.fc(features)
        return features, out
180
AutoDL-Projects/xautodl/models/CifarResNet.py
Normal file
@@ -0,0 +1,180 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .initialization import initialize_resnet
from .SharedUtils import additive_func


class Downsample(nn.Module):
    def __init__(self, nIn, nOut, stride):
        super(Downsample, self).__init__()
        assert stride == 2 and nOut == 2 * nIn, "stride:{} IO:{},{}".format(
            stride, nIn, nOut
        )
        self.in_dim = nIn
        self.out_dim = nOut
        self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        x = self.avg(x)
        out = self.conv(x)
        return out


class ConvBNReLU(nn.Module):
    def __init__(self, nIn, nOut, kernel, stride, padding, bias, relu):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(
            nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, bias=bias
        )
        self.bn = nn.BatchNorm2d(nOut)
        if relu:
            self.relu = nn.ReLU(inplace=True)
        else:
            self.relu = None
        self.out_dim = nOut
        self.num_conv = 1

    def forward(self, x):
        conv = self.conv(x)
        bn = self.bn(conv)
        if self.relu:
            return self.relu(bn)
        else:
            return bn


class ResNetBasicblock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride):
        super(ResNetBasicblock, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, True)
        self.conv_b = ConvBNReLU(planes, planes, 3, 1, 1, False, False)
        if stride == 2:
            self.downsample = Downsample(inplanes, planes, stride)
        elif inplanes != planes:
            self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, False)
        else:
            self.downsample = None
        self.out_dim = planes
        self.num_conv = 2

    def forward(self, inputs):

        basicblock = self.conv_a(inputs)
        basicblock = self.conv_b(basicblock)

        if self.downsample is not None:
            residual = self.downsample(inputs)
        else:
            residual = inputs
        out = additive_func(residual, basicblock)
        return F.relu(out, inplace=True)


class ResNetBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride):
        super(ResNetBottleneck, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, True)
        self.conv_3x3 = ConvBNReLU(planes, planes, 3, stride, 1, False, True)
        self.conv_1x4 = ConvBNReLU(
            planes, planes * self.expansion, 1, 1, 0, False, False
        )
        if stride == 2:
            self.downsample = Downsample(inplanes, planes * self.expansion, stride)
        elif inplanes != planes * self.expansion:
            self.downsample = ConvBNReLU(
                inplanes, planes * self.expansion, 1, 1, 0, False, False
            )
        else:
            self.downsample = None
        self.out_dim = planes * self.expansion
        self.num_conv = 3

    def forward(self, inputs):

        bottleneck = self.conv_1x1(inputs)
        bottleneck = self.conv_3x3(bottleneck)
        bottleneck = self.conv_1x4(bottleneck)

        if self.downsample is not None:
            residual = self.downsample(inputs)
        else:
            residual = inputs
        out = additive_func(residual, bottleneck)
        return F.relu(out, inplace=True)


class CifarResNet(nn.Module):
    def __init__(self, block_name, depth, num_classes, zero_init_residual):
        super(CifarResNet, self).__init__()

        # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
        if block_name == "ResNetBasicblock":
            block = ResNetBasicblock
            assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
            layer_blocks = (depth - 2) // 6
        elif block_name == "ResNetBottleneck":
            block = ResNetBottleneck
            assert (depth - 2) % 9 == 0, "depth should be one of 164"
            layer_blocks = (depth - 2) // 9
        else:
            raise ValueError("invalid block : {:}".format(block_name))

        self.message = "CifarResNet : Block : {:}, Depth : {:}, Layers for each block : {:}".format(
            block_name, depth, layer_blocks
        )
        self.num_classes = num_classes
        self.channels = [16]
        self.layers = nn.ModuleList([ConvBNReLU(3, 16, 3, 1, 1, False, True)])
        for stage in range(3):
            for iL in range(layer_blocks):
                iC = self.channels[-1]
                planes = 16 * (2 ** stage)
                stride = 2 if stage > 0 and iL == 0 else 1
                module = block(iC, planes, stride)
                self.channels.append(module.out_dim)
                self.layers.append(module)
                self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
                    stage,
                    iL,
                    layer_blocks,
                    len(self.layers) - 1,
                    iC,
                    module.out_dim,
                    stride,
                )

        self.avgpool = nn.AvgPool2d(8)
        self.classifier = nn.Linear(module.out_dim, num_classes)
        assert (
            sum(x.num_conv for x in self.layers) + 1 == depth
        ), "invalid depth check {:} vs {:}".format(
            sum(x.num_conv for x in self.layers) + 1, depth
        )

        self.apply(initialize_resnet)
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, ResNetBasicblock):
                    nn.init.constant_(m.conv_b.bn.weight, 0)
                elif isinstance(m, ResNetBottleneck):
                    nn.init.constant_(m.conv_1x4.bn.weight, 0)

    def get_message(self):
        return self.message

    def forward(self, inputs):
        x = inputs
        for i, layer in enumerate(self.layers):
            x = layer(x)
        features = self.avgpool(x)
        features = features.view(features.size(0), -1)
        logits = self.classifier(features)
        return features, logits
115
AutoDL-Projects/xautodl/models/CifarWideResNet.py
Normal file
@@ -0,0 +1,115 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .initialization import initialize_resnet


class WideBasicblock(nn.Module):
    def __init__(self, inplanes, planes, stride, dropout=False):
        super(WideBasicblock, self).__init__()

        self.bn_a = nn.BatchNorm2d(inplanes)
        self.conv_a = nn.Conv2d(
            inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )

        self.bn_b = nn.BatchNorm2d(planes)
        if dropout:
            self.dropout = nn.Dropout2d(p=0.5, inplace=True)
        else:
            self.dropout = None
        self.conv_b = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )

        if inplanes != planes:
            self.downsample = nn.Conv2d(
                inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False
            )
        else:
            self.downsample = None

    def forward(self, x):

        basicblock = self.bn_a(x)
        basicblock = F.relu(basicblock)
        basicblock = self.conv_a(basicblock)

        basicblock = self.bn_b(basicblock)
        basicblock = F.relu(basicblock)
        if self.dropout is not None:
            basicblock = self.dropout(basicblock)
        basicblock = self.conv_b(basicblock)

        if self.downsample is not None:
            x = self.downsample(x)

        return x + basicblock


class CifarWideResNet(nn.Module):
    """
    ResNet optimized for the Cifar dataset, as specified in
    https://arxiv.org/abs/1512.03385.pdf
    """

    def __init__(self, depth, widen_factor, num_classes, dropout):
        super(CifarWideResNet, self).__init__()

        # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
        assert (depth - 4) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
        layer_blocks = (depth - 4) // 6
        print(
            "CifarPreResNet : Depth : {} , Layers for each block : {}".format(
                depth, layer_blocks
            )
        )

        self.num_classes = num_classes
        self.dropout = dropout
        self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)

        self.message = "Wide ResNet : depth={:}, widen_factor={:}, class={:}".format(
            depth, widen_factor, num_classes
        )
        self.inplanes = 16
        self.stage_1 = self._make_layer(
            WideBasicblock, 16 * widen_factor, layer_blocks, 1
        )
        self.stage_2 = self._make_layer(
            WideBasicblock, 32 * widen_factor, layer_blocks, 2
        )
        self.stage_3 = self._make_layer(
            WideBasicblock, 64 * widen_factor, layer_blocks, 2
        )
        self.lastact = nn.Sequential(
            nn.BatchNorm2d(64 * widen_factor), nn.ReLU(inplace=True)
        )
        self.avgpool = nn.AvgPool2d(8)
        self.classifier = nn.Linear(64 * widen_factor, num_classes)

        self.apply(initialize_resnet)

    def get_message(self):
        return self.message

    def _make_layer(self, block, planes, blocks, stride):

        layers = []
        layers.append(block(self.inplanes, planes, stride, self.dropout))
        self.inplanes = planes
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, 1, self.dropout))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv_3x3(x)
        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)
        x = self.lastact(x)
        x = self.avgpool(x)
        features = x.view(x.size(0), -1)
        outs = self.classifier(features)
        return features, outs
||||
117
AutoDL-Projects/xautodl/models/ImageNet_MobileNetV2.py
Normal file
117
AutoDL-Projects/xautodl/models/ImageNet_MobileNetV2.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
|
||||
from torch import nn
|
||||
from .initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
padding = (kernel_size - 1) // 2
|
||||
self.conv = nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
groups=groups,
|
||||
bias=False,
|
||||
)
|
||||
self.bn = nn.BatchNorm2d(out_planes)
|
||||
self.relu = nn.ReLU6(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv(x)
|
||||
out = self.bn(out)
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
def __init__(self, inp, oup, stride, expand_ratio):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(round(inp * expand_ratio))
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
|
||||
layers = []
|
||||
if expand_ratio != 1:
|
||||
# pw
|
||||
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
||||
layers.extend(
|
||||
[
|
||||
# dw
|
||||
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
|
||||
# pw-linear
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
]
|
||||
)
|
||||
self.conv = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_res_connect:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class MobileNetV2(nn.Module):
|
||||
def __init__(
|
||||
self, num_classes, width_mult, input_channel, last_channel, block_name, dropout
|
||||
):
|
||||
super(MobileNetV2, self).__init__()
|
||||
if block_name == "InvertedResidual":
|
||||
block = InvertedResidual
|
||||
else:
|
||||
raise ValueError("invalid block name : {:}".format(block_name))
|
||||
inverted_residual_setting = [
|
||||
# t, c, n, s
|
||||
[1, 16, 1, 1],
|
||||
[6, 24, 2, 2],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 320, 1, 1],
|
||||
]
|
||||
|
||||
# building first layer
|
||||
input_channel = int(input_channel * width_mult)
|
||||
self.last_channel = int(last_channel * max(1.0, width_mult))
|
||||
features = [ConvBNReLU(3, input_channel, stride=2)]
|
||||
# building inverted residual blocks
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
output_channel = int(c * width_mult)
|
||||
for i in range(n):
|
||||
stride = s if i == 0 else 1
|
||||
features.append(
|
||||
block(input_channel, output_channel, stride, expand_ratio=t)
|
||||
)
|
||||
input_channel = output_channel
|
||||
# building last several layers
|
||||
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
|
||||
# make it nn.Sequential
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
# building classifier
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(self.last_channel, num_classes),
|
||||
)
|
||||
self.message = "MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}".format(
|
||||
width_mult, input_channel, last_channel, block_name, dropout
|
||||
)
|
||||
|
||||
# weight initialization
|
||||
self.apply(initialize_resnet)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
features = self.features(inputs)
|
||||
vectors = features.mean([2, 3])
|
||||
predicts = self.classifier(vectors)
|
||||
return features, predicts
|
||||
217
AutoDL-Projects/xautodl/models/ImageNet_ResNet.py
Normal file
@@ -0,0 +1,217 @@
# Deep Residual Learning for Image Recognition, CVPR 2016
import torch.nn as nn
from .initialization import initialize_resnet


def conv3x3(in_planes, out_planes, stride=1, groups=1):
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        groups=groups,
        bias=False,
    )


def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(
        self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64
    ):
        super(BasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(
        self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64
    ):
        super(Bottleneck, self).__init__()
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = conv3x3(width, width, stride, groups)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(
        self,
        block_name,
        layers,
        deep_stem,
        num_classes,
        zero_init_residual,
        groups,
        width_per_group,
    ):
        super(ResNet, self).__init__()

        # planes = [int(width_per_group * groups * 2 ** i) for i in range(4)]
        if block_name == "BasicBlock":
            block = BasicBlock
        elif block_name == "Bottleneck":
            block = Bottleneck
        else:
            raise ValueError("invalid block-name : {:}".format(block_name))

        if not deep_stem:
            self.conv = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )
        self.inplanes = 64
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(
            block, 64, layers[0], stride=1, groups=groups, base_width=width_per_group
        )
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, groups=groups, base_width=width_per_group
        )
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, groups=groups, base_width=width_per_group
        )
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, groups=groups, base_width=width_per_group
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.message = (
            "block = {:}, layers = {:}, deep_stem = {:}, num_classes = {:}".format(
                block, layers, deep_stem, num_classes
            )
        )

        self.apply(initialize_resnet)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride, groups, base_width):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if stride == 2:
                downsample = nn.Sequential(
                    nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                    conv1x1(self.inplanes, planes * block.expansion, 1),
                    nn.BatchNorm2d(planes * block.expansion),
                )
            elif stride == 1:
                downsample = nn.Sequential(
                    conv1x1(self.inplanes, planes * block.expansion, stride),
                    nn.BatchNorm2d(planes * block.expansion),
                )
            else:
                raise ValueError("invalid stride [{:}] for downsample".format(stride))

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, groups, base_width)
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, 1, None, groups, base_width))

        return nn.Sequential(*layers)

    def get_message(self):
        return self.message

    def forward(self, x):
        x = self.conv(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        features = self.avgpool(x)
        features = features.view(features.size(0), -1)
        logits = self.fc(features)

        return features, logits
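Similarly, a hedged sketch of constructing the ResNet above; the stage depths follow the common ResNet-50 recipe and are illustrative assumptions, not values from this commit.

import torch

net = ResNet(
    block_name="Bottleneck",
    layers=[3, 4, 6, 3],     # assumed ResNet-50 stage depths
    deep_stem=False,
    num_classes=1000,
    zero_init_residual=True,
    groups=1,
    width_per_group=64,
)
features, logits = net(torch.randn(2, 3, 224, 224))
# features: flattened pooled vector [2, 512 * expansion], logits: [2, num_classes]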
37
AutoDL-Projects/xautodl/models/SharedUtils.py
Normal file
@@ -0,0 +1,37 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn


def additive_func(A, B):
    assert A.dim() == B.dim() and A.size(0) == B.size(0), "{:} vs {:}".format(
        A.size(), B.size()
    )
    C = min(A.size(1), B.size(1))
    if A.size(1) == B.size(1):
        return A + B
    elif A.size(1) < B.size(1):
        out = B.clone()
        out[:, :C] += A
        return out
    else:
        out = A.clone()
        out[:, :C] += B
        return out


def change_key(key, value):
    def func(m):
        if hasattr(m, key):
            setattr(m, key, value)

    return func


def parse_channel_info(xstring):
    blocks = xstring.split(" ")
    blocks = [x.split("-") for x in blocks]
    blocks = [[int(_) for _ in x] for x in blocks]
    return blocks
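A short illustration (assumed, not from the file) of the two helpers above: parse_channel_info turns a space/dash separated string into nested integer lists, and additive_func adds two tensors whose channel counts may differ, by adding the narrower tensor onto the first channels of a copy of the wider one.

import torch

print(parse_channel_info("8-16 16-32 32-64"))   # [[8, 16], [16, 32], [32, 64]]

A = torch.randn(4, 8, 32, 32)
B = torch.randn(4, 16, 32, 32)
C = additive_func(A, B)      # A is added onto the first 8 channels of a copy of B
print(C.shape)               # torch.Size([4, 16, 32, 32])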
326
AutoDL-Projects/xautodl/models/__init__.py
Normal file
@@ -0,0 +1,326 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from os import path as osp
|
||||
from typing import List, Text
|
||||
import torch
|
||||
|
||||
__all__ = [
|
||||
"change_key",
|
||||
"get_cell_based_tiny_net",
|
||||
"get_search_spaces",
|
||||
"get_cifar_models",
|
||||
"get_imagenet_models",
|
||||
"obtain_model",
|
||||
"obtain_search_model",
|
||||
"load_net_from_checkpoint",
|
||||
"CellStructure",
|
||||
"CellArchitectures",
|
||||
]
|
||||
|
||||
# useful modules
|
||||
from xautodl.config_utils import dict2config
|
||||
from .SharedUtils import change_key
|
||||
from .cell_searchs import CellStructure, CellArchitectures
|
||||
|
||||
|
||||
# Cell-based NAS Models
|
||||
def get_cell_based_tiny_net(config):
|
||||
if isinstance(config, dict):
|
||||
config = dict2config(config, None) # to support the argument being a dict
|
||||
super_type = getattr(config, "super_type", "basic")
|
||||
group_names = ["DARTS-V1", "DARTS-V2", "GDAS", "SETN", "ENAS", "RANDOM", "generic"]
|
||||
if super_type == "basic" and config.name in group_names:
|
||||
from .cell_searchs import nas201_super_nets as nas_super_nets
|
||||
|
||||
try:
|
||||
return nas_super_nets[config.name](
|
||||
config.C,
|
||||
config.N,
|
||||
config.max_nodes,
|
||||
config.num_classes,
|
||||
config.space,
|
||||
config.affine,
|
||||
config.track_running_stats,
|
||||
)
|
||||
except:
|
||||
return nas_super_nets[config.name](
|
||||
config.C, config.N, config.max_nodes, config.num_classes, config.space
|
||||
)
|
||||
elif super_type == "search-shape":
|
||||
from .shape_searchs import GenericNAS301Model
|
||||
|
||||
genotype = CellStructure.str2structure(config.genotype)
|
||||
return GenericNAS301Model(
|
||||
config.candidate_Cs,
|
||||
config.max_num_Cs,
|
||||
genotype,
|
||||
config.num_classes,
|
||||
config.affine,
|
||||
config.track_running_stats,
|
||||
)
|
||||
elif super_type == "nasnet-super":
|
||||
from .cell_searchs import nasnet_super_nets as nas_super_nets
|
||||
|
||||
return nas_super_nets[config.name](
|
||||
config.C,
|
||||
config.N,
|
||||
config.steps,
|
||||
config.multiplier,
|
||||
config.stem_multiplier,
|
||||
config.num_classes,
|
||||
config.space,
|
||||
config.affine,
|
||||
config.track_running_stats,
|
||||
)
|
||||
elif config.name == "infer.tiny":
|
||||
from .cell_infers import TinyNetwork
|
||||
|
||||
if hasattr(config, "genotype"):
|
||||
genotype = config.genotype
|
||||
elif hasattr(config, "arch_str"):
|
||||
genotype = CellStructure.str2structure(config.arch_str)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Can not find genotype from this config : {:}".format(config)
|
||||
)
|
||||
return TinyNetwork(config.C, config.N, genotype, config.num_classes)
|
||||
elif config.name == "infer.shape.tiny":
|
||||
from .shape_infers import DynamicShapeTinyNet
|
||||
|
||||
if isinstance(config.channels, str):
|
||||
channels = tuple([int(x) for x in config.channels.split(":")])
|
||||
else:
|
||||
channels = config.channels
|
||||
genotype = CellStructure.str2structure(config.genotype)
|
||||
return DynamicShapeTinyNet(channels, genotype, config.num_classes)
|
||||
elif config.name == "infer.nasnet-cifar":
|
||||
from .cell_infers import NASNetonCIFAR
|
||||
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise ValueError("invalid network name : {:}".format(config.name))
|
||||
|
||||
|
||||
# obtain the search space, i.e., a list of candidate operation names (topology space)
# or a dict describing the size search space
def get_search_spaces(xtype, name) -> List[Text]:
|
||||
if xtype == "cell" or xtype == "tss": # The topology search space.
|
||||
from .cell_operations import SearchSpaceNames
|
||||
|
||||
assert name in SearchSpaceNames, "invalid name [{:}] in {:}".format(
|
||||
name, SearchSpaceNames.keys()
|
||||
)
|
||||
return SearchSpaceNames[name]
|
||||
elif xtype == "sss": # The size search space.
|
||||
if name in ["nats-bench", "nats-bench-size"]:
|
||||
return {"candidates": [8, 16, 24, 32, 40, 48, 56, 64], "numbers": 5}
|
||||
else:
|
||||
raise ValueError("Invalid name : {:}".format(name))
|
||||
else:
|
||||
raise ValueError("invalid search-space type is {:}".format(xtype))
|
||||
|
||||
|
||||
def get_cifar_models(config, extra_path=None):
|
||||
super_type = getattr(config, "super_type", "basic")
|
||||
if super_type == "basic":
|
||||
from .CifarResNet import CifarResNet
|
||||
from .CifarDenseNet import DenseNet
|
||||
from .CifarWideResNet import CifarWideResNet
|
||||
|
||||
if config.arch == "resnet":
|
||||
return CifarResNet(
|
||||
config.module, config.depth, config.class_num, config.zero_init_residual
|
||||
)
|
||||
elif config.arch == "densenet":
|
||||
return DenseNet(
|
||||
config.growthRate,
|
||||
config.depth,
|
||||
config.reduction,
|
||||
config.class_num,
|
||||
config.bottleneck,
|
||||
)
|
||||
elif config.arch == "wideresnet":
|
||||
return CifarWideResNet(
|
||||
config.depth, config.wide_factor, config.class_num, config.dropout
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid module type : {:}".format(config.arch))
|
||||
elif super_type.startswith("infer"):
|
||||
from .shape_infers import InferWidthCifarResNet
|
||||
from .shape_infers import InferDepthCifarResNet
|
||||
from .shape_infers import InferCifarResNet
|
||||
from .cell_infers import NASNetonCIFAR
|
||||
|
||||
assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format(
|
||||
super_type
|
||||
)
|
||||
infer_mode = super_type.split("-")[1]
|
||||
if infer_mode == "width":
|
||||
return InferWidthCifarResNet(
|
||||
config.module,
|
||||
config.depth,
|
||||
config.xchannels,
|
||||
config.class_num,
|
||||
config.zero_init_residual,
|
||||
)
|
||||
elif infer_mode == "depth":
|
||||
return InferDepthCifarResNet(
|
||||
config.module,
|
||||
config.depth,
|
||||
config.xblocks,
|
||||
config.class_num,
|
||||
config.zero_init_residual,
|
||||
)
|
||||
elif infer_mode == "shape":
|
||||
return InferCifarResNet(
|
||||
config.module,
|
||||
config.depth,
|
||||
config.xblocks,
|
||||
config.xchannels,
|
||||
config.class_num,
|
||||
config.zero_init_residual,
|
||||
)
|
||||
elif infer_mode == "nasnet.cifar":
|
||||
genotype = config.genotype
|
||||
if extra_path is not None: # reload genotype by extra_path
|
||||
if not osp.isfile(extra_path):
|
||||
raise ValueError("invalid extra_path : {:}".format(extra_path))
|
||||
xdata = torch.load(extra_path)
|
||||
current_epoch = xdata["epoch"]
|
||||
genotype = xdata["genotypes"][current_epoch - 1]
|
||||
C = config.C if hasattr(config, "C") else config.ichannel
|
||||
N = config.N if hasattr(config, "N") else config.layers
|
||||
return NASNetonCIFAR(
|
||||
C, N, config.stem_multi, config.class_num, genotype, config.auxiliary
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid infer-mode : {:}".format(infer_mode))
|
||||
else:
|
||||
raise ValueError("invalid super-type : {:}".format(super_type))
|
||||
|
||||
|
||||
def get_imagenet_models(config):
|
||||
super_type = getattr(config, "super_type", "basic")
|
||||
if super_type == "basic":
|
||||
from .ImageNet_ResNet import ResNet
|
||||
from .ImageNet_MobileNetV2 import MobileNetV2
|
||||
|
||||
if config.arch == "resnet":
|
||||
return ResNet(
|
||||
config.block_name,
|
||||
config.layers,
|
||||
config.deep_stem,
|
||||
config.class_num,
|
||||
config.zero_init_residual,
|
||||
config.groups,
|
||||
config.width_per_group,
|
||||
)
|
||||
elif config.arch == "mobilenet_v2":
|
||||
return MobileNetV2(
|
||||
config.class_num,
|
||||
config.width_multi,
|
||||
config.input_channel,
|
||||
config.last_channel,
|
||||
"InvertedResidual",
|
||||
config.dropout,
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid arch : {:}".format(config.arch))
|
||||
elif super_type.startswith("infer"): # NAS searched architecture
|
||||
assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format(
|
||||
super_type
|
||||
)
|
||||
infer_mode = super_type.split("-")[1]
|
||||
if infer_mode == "shape":
|
||||
from .shape_infers import InferImagenetResNet
|
||||
from .shape_infers import InferMobileNetV2
|
||||
|
||||
if config.arch == "resnet":
|
||||
return InferImagenetResNet(
|
||||
config.block_name,
|
||||
config.layers,
|
||||
config.xblocks,
|
||||
config.xchannels,
|
||||
config.deep_stem,
|
||||
config.class_num,
|
||||
config.zero_init_residual,
|
||||
)
|
||||
elif config.arch == "MobileNetV2":
|
||||
return InferMobileNetV2(
|
||||
config.class_num, config.xchannels, config.xblocks, config.dropout
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid arch-mode : {:}".format(config.arch))
|
||||
else:
|
||||
raise ValueError("invalid infer-mode : {:}".format(infer_mode))
|
||||
else:
|
||||
raise ValueError("invalid super-type : {:}".format(super_type))
|
||||
|
||||
|
||||
# Try to obtain the network by config.
|
||||
def obtain_model(config, extra_path=None):
|
||||
if config.dataset == "cifar":
|
||||
return get_cifar_models(config, extra_path)
|
||||
elif config.dataset == "imagenet":
|
||||
return get_imagenet_models(config)
|
||||
else:
|
||||
raise ValueError("invalid dataset in the model config : {:}".format(config))
|
||||
|
||||
|
||||
def obtain_search_model(config):
|
||||
if config.dataset == "cifar":
|
||||
if config.arch == "resnet":
|
||||
from .shape_searchs import SearchWidthCifarResNet
|
||||
from .shape_searchs import SearchDepthCifarResNet
|
||||
from .shape_searchs import SearchShapeCifarResNet
|
||||
|
||||
if config.search_mode == "width":
|
||||
return SearchWidthCifarResNet(
|
||||
config.module, config.depth, config.class_num
|
||||
)
|
||||
elif config.search_mode == "depth":
|
||||
return SearchDepthCifarResNet(
|
||||
config.module, config.depth, config.class_num
|
||||
)
|
||||
elif config.search_mode == "shape":
|
||||
return SearchShapeCifarResNet(
|
||||
config.module, config.depth, config.class_num
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid search mode : {:}".format(config.search_mode))
|
||||
elif config.arch == "simres":
|
||||
from .shape_searchs import SearchWidthSimResNet
|
||||
|
||||
if config.search_mode == "width":
|
||||
return SearchWidthSimResNet(config.depth, config.class_num)
|
||||
else:
|
||||
raise ValueError("invalid search mode : {:}".format(config.search_mode))
|
||||
else:
|
||||
raise ValueError(
|
||||
"invalid arch : {:} for dataset [{:}]".format(
|
||||
config.arch, config.dataset
|
||||
)
|
||||
)
|
||||
elif config.dataset == "imagenet":
|
||||
from .shape_searchs import SearchShapeImagenetResNet
|
||||
|
||||
assert config.search_mode == "shape", "invalid search-mode : {:}".format(
|
||||
config.search_mode
|
||||
)
|
||||
if config.arch == "resnet":
|
||||
return SearchShapeImagenetResNet(
|
||||
config.block_name, config.layers, config.deep_stem, config.class_num
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid model config : {:}".format(config))
|
||||
else:
|
||||
raise ValueError("invalid dataset in the model config : {:}".format(config))
|
||||
|
||||
|
||||
def load_net_from_checkpoint(checkpoint):
|
||||
assert osp.isfile(checkpoint), "checkpoint {:} does not exist".format(checkpoint)
|
||||
checkpoint = torch.load(checkpoint)
|
||||
model_config = dict2config(checkpoint["model-config"], None)
|
||||
model = obtain_model(model_config)
|
||||
model.load_state_dict(checkpoint["base-model"])
|
||||
return model
|
||||
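To show how this factory is typically driven, a hedged sketch that builds a CIFAR model through obtain_model via a dict-style config; the field names mirror get_cifar_models above, while the concrete values (block name, depth, class count) are assumptions for illustration.

from xautodl.config_utils import dict2config
from xautodl.models import obtain_model

model_config = dict2config(
    {
        "dataset": "cifar",
        "arch": "resnet",
        "module": "ResNetBasicblock",   # assumed block name for CifarResNet
        "depth": 20,
        "class_num": 10,
        "zero_init_residual": False,
        "super_type": "basic",
    },
    None,
)
net = obtain_model(model_config)   # dispatches to get_cifar_models -> CifarResNet
print(net)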
5
AutoDL-Projects/xautodl/models/cell_infers/__init__.py
Normal file
@@ -0,0 +1,5 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .tiny_network import TinyNetwork
from .nasnet_cifar import NASNetonCIFAR
155
AutoDL-Projects/xautodl/models/cell_infers/cells.py
Normal file
@@ -0,0 +1,155 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
|
||||
from xautodl.models.cell_operations import OPS
|
||||
|
||||
|
||||
# Cell for NAS-Bench-201
|
||||
class InferCell(nn.Module):
|
||||
def __init__(
|
||||
self, genotype, C_in, C_out, stride, affine=True, track_running_stats=True
|
||||
):
|
||||
super(InferCell, self).__init__()
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
self.node_IN = []
|
||||
self.node_IX = []
|
||||
self.genotype = deepcopy(genotype)
|
||||
for i in range(1, len(genotype)):
|
||||
node_info = genotype[i - 1]
|
||||
cur_index = []
|
||||
cur_innod = []
|
||||
for (op_name, op_in) in node_info:
|
||||
if op_in == 0:
|
||||
layer = OPS[op_name](
|
||||
C_in, C_out, stride, affine, track_running_stats
|
||||
)
|
||||
else:
|
||||
layer = OPS[op_name](C_out, C_out, 1, affine, track_running_stats)
|
||||
cur_index.append(len(self.layers))
|
||||
cur_innod.append(op_in)
|
||||
self.layers.append(layer)
|
||||
self.node_IX.append(cur_index)
|
||||
self.node_IN.append(cur_innod)
|
||||
self.nodes = len(genotype)
|
||||
self.in_dim = C_in
|
||||
self.out_dim = C_out
|
||||
|
||||
def extra_repr(self):
|
||||
string = "info :: nodes={nodes}, inC={in_dim}, outC={out_dim}".format(
|
||||
**self.__dict__
|
||||
)
|
||||
laystr = []
|
||||
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)):
|
||||
y = [
|
||||
"I{:}-L{:}".format(_ii, _il)
|
||||
for _il, _ii in zip(node_layers, node_innods)
|
||||
]
|
||||
x = "{:}<-({:})".format(i + 1, ",".join(y))
|
||||
laystr.append(x)
|
||||
return (
|
||||
string
|
||||
+ ", [{:}]".format(" | ".join(laystr))
|
||||
+ ", {:}".format(self.genotype.tostr())
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
nodes = [inputs]
|
||||
for i, (node_layers, node_innods) in enumerate(zip(self.node_IX, self.node_IN)):
|
||||
node_feature = sum(
|
||||
self.layers[_il](nodes[_ii])
|
||||
for _il, _ii in zip(node_layers, node_innods)
|
||||
)
|
||||
nodes.append(node_feature)
|
||||
return nodes[-1]
|
||||
|
||||
|
||||
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
|
||||
class NASNetInferCell(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
genotype,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
):
|
||||
super(NASNetInferCell, self).__init__()
|
||||
self.reduction = reduction
|
||||
if reduction_prev:
|
||||
self.preprocess0 = OPS["skip_connect"](
|
||||
C_prev_prev, C, 2, affine, track_running_stats
|
||||
)
|
||||
else:
|
||||
self.preprocess0 = OPS["nor_conv_1x1"](
|
||||
C_prev_prev, C, 1, affine, track_running_stats
|
||||
)
|
||||
self.preprocess1 = OPS["nor_conv_1x1"](
|
||||
C_prev, C, 1, affine, track_running_stats
|
||||
)
|
||||
|
||||
if not reduction:
|
||||
nodes, concats = genotype["normal"], genotype["normal_concat"]
|
||||
else:
|
||||
nodes, concats = genotype["reduce"], genotype["reduce_concat"]
|
||||
self._multiplier = len(concats)
|
||||
self._concats = concats
|
||||
self._steps = len(nodes)
|
||||
self._nodes = nodes
|
||||
self.edges = nn.ModuleDict()
|
||||
for i, node in enumerate(nodes):
|
||||
for in_node in node:
|
||||
name, j = in_node[0], in_node[1]
|
||||
stride = 2 if reduction and j < 2 else 1
|
||||
node_str = "{:}<-{:}".format(i + 2, j)
|
||||
self.edges[node_str] = OPS[name](
|
||||
C, C, stride, affine, track_running_stats
|
||||
)
|
||||
|
||||
# [TODO] to support drop_prob in this function..
|
||||
def forward(self, s0, s1, unused_drop_prob):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i, node in enumerate(self._nodes):
|
||||
clist = []
|
||||
for in_node in node:
|
||||
name, j = in_node[0], in_node[1]
|
||||
node_str = "{:}<-{:}".format(i + 2, j)
|
||||
op = self.edges[node_str]
|
||||
clist.append(op(states[j]))
|
||||
states.append(sum(clist))
|
||||
return torch.cat([states[x] for x in self._concats], dim=1)
|
||||
|
||||
|
||||
class AuxiliaryHeadCIFAR(nn.Module):
|
||||
def __init__(self, C, num_classes):
|
||||
"""assuming input size 8x8"""
|
||||
super(AuxiliaryHeadCIFAR, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.AvgPool2d(
|
||||
5, stride=3, padding=0, count_include_pad=False
|
||||
), # image size = 2 x 2
|
||||
nn.Conv2d(C, 128, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 768, 2, bias=False),
|
||||
nn.BatchNorm2d(768),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
self.classifier = nn.Linear(768, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.classifier(x.view(x.size(0), -1))
|
||||
return x
|
||||
118
AutoDL-Projects/xautodl/models/cell_infers/nasnet_cifar.py
Normal file
@@ -0,0 +1,118 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
|
||||
from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR
|
||||
|
||||
|
||||
# The macro structure is based on NASNet
|
||||
class NASNetonCIFAR(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C,
|
||||
N,
|
||||
stem_multiplier,
|
||||
num_classes,
|
||||
genotype,
|
||||
auxiliary,
|
||||
affine=True,
|
||||
track_running_stats=True,
|
||||
):
|
||||
super(NASNetonCIFAR, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C * stem_multiplier),
|
||||
)
|
||||
|
||||
# config for each layer
|
||||
layer_channels = (
|
||||
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
|
||||
)
|
||||
layer_reductions = (
|
||||
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
|
||||
)
|
||||
|
||||
C_prev_prev, C_prev, C_curr, reduction_prev = (
|
||||
C * stem_multiplier,
|
||||
C * stem_multiplier,
|
||||
C,
|
||||
False,
|
||||
)
|
||||
self.auxiliary_index = None
|
||||
self.auxiliary_head = None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
cell = InferCell(
|
||||
genotype,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C_curr,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
self.cells.append(cell)
|
||||
C_prev_prev, C_prev, reduction_prev = (
|
||||
C_prev,
|
||||
cell._multiplier * C_curr,
|
||||
reduction,
|
||||
)
|
||||
if reduction and C_curr == C * 4 and auxiliary:
|
||||
self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes)
|
||||
self.auxiliary_index = index
|
||||
self._Layer = len(self.cells)
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.drop_path_prob = -1
|
||||
|
||||
def update_drop_path(self, drop_path_prob):
|
||||
self.drop_path_prob = drop_path_prob
|
||||
|
||||
def auxiliary_param(self):
|
||||
if self.auxiliary_head is None:
|
||||
return []
|
||||
else:
|
||||
return list(self.auxiliary_head.parameters())
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, N={_layerN}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
stem_feature, logits_aux = self.stem(inputs), None
|
||||
cell_results = [stem_feature, stem_feature]
|
||||
for i, cell in enumerate(self.cells):
|
||||
cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob)
|
||||
cell_results.append(cell_feature)
|
||||
if (
|
||||
self.auxiliary_index is not None
|
||||
and i == self.auxiliary_index
|
||||
and self.training
|
||||
):
|
||||
logits_aux = self.auxiliary_head(cell_results[-1])
|
||||
out = self.lastact(cell_results[-1])
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
if logits_aux is None:
|
||||
return out, logits
|
||||
else:
|
||||
return out, [logits, logits_aux]
|
||||
63
AutoDL-Projects/xautodl/models/cell_infers/tiny_network.py
Normal file
@@ -0,0 +1,63 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .cells import InferCell
|
||||
|
||||
|
||||
# The macro structure for architectures in NAS-Bench-201
|
||||
class TinyNetwork(nn.Module):
|
||||
def __init__(self, C, N, genotype, num_classes):
|
||||
super(TinyNetwork, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
|
||||
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
C_prev = C
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2, True)
|
||||
else:
|
||||
cell = InferCell(genotype, C_prev, C_curr, 1)
|
||||
self.cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self._Layer = len(self.cells)
|
||||
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, N={_layerN}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
||||
553
AutoDL-Projects/xautodl/models/cell_operations.py
Normal file
@@ -0,0 +1,553 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
__all__ = ["OPS", "RAW_OP_CLASSES", "ResNetBasicblock", "SearchSpaceNames"]
|
||||
|
||||
OPS = {
|
||||
"none": lambda C_in, C_out, stride, affine, track_running_stats: Zero(
|
||||
C_in, C_out, stride
|
||||
),
|
||||
"avg_pool_3x3": lambda C_in, C_out, stride, affine, track_running_stats: POOLING(
|
||||
C_in, C_out, stride, "avg", affine, track_running_stats
|
||||
),
|
||||
"max_pool_3x3": lambda C_in, C_out, stride, affine, track_running_stats: POOLING(
|
||||
C_in, C_out, stride, "max", affine, track_running_stats
|
||||
),
|
||||
"nor_conv_7x7": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(
|
||||
C_in,
|
||||
C_out,
|
||||
(7, 7),
|
||||
(stride, stride),
|
||||
(3, 3),
|
||||
(1, 1),
|
||||
affine,
|
||||
track_running_stats,
|
||||
),
|
||||
"nor_conv_3x3": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(
|
||||
C_in,
|
||||
C_out,
|
||||
(3, 3),
|
||||
(stride, stride),
|
||||
(1, 1),
|
||||
(1, 1),
|
||||
affine,
|
||||
track_running_stats,
|
||||
),
|
||||
"nor_conv_1x1": lambda C_in, C_out, stride, affine, track_running_stats: ReLUConvBN(
|
||||
C_in,
|
||||
C_out,
|
||||
(1, 1),
|
||||
(stride, stride),
|
||||
(0, 0),
|
||||
(1, 1),
|
||||
affine,
|
||||
track_running_stats,
|
||||
),
|
||||
"dua_sepc_3x3": lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(
|
||||
C_in,
|
||||
C_out,
|
||||
(3, 3),
|
||||
(stride, stride),
|
||||
(1, 1),
|
||||
(1, 1),
|
||||
affine,
|
||||
track_running_stats,
|
||||
),
|
||||
"dua_sepc_5x5": lambda C_in, C_out, stride, affine, track_running_stats: DualSepConv(
|
||||
C_in,
|
||||
C_out,
|
||||
(5, 5),
|
||||
(stride, stride),
|
||||
(2, 2),
|
||||
(1, 1),
|
||||
affine,
|
||||
track_running_stats,
|
||||
),
|
||||
"dil_sepc_3x3": lambda C_in, C_out, stride, affine, track_running_stats: SepConv(
|
||||
C_in,
|
||||
C_out,
|
||||
(3, 3),
|
||||
(stride, stride),
|
||||
(2, 2),
|
||||
(2, 2),
|
||||
affine,
|
||||
track_running_stats,
|
||||
),
|
||||
"dil_sepc_5x5": lambda C_in, C_out, stride, affine, track_running_stats: SepConv(
|
||||
C_in,
|
||||
C_out,
|
||||
(5, 5),
|
||||
(stride, stride),
|
||||
(4, 4),
|
||||
(2, 2),
|
||||
affine,
|
||||
track_running_stats,
|
||||
),
|
||||
"skip_connect": lambda C_in, C_out, stride, affine, track_running_stats: Identity()
|
||||
if stride == 1 and C_in == C_out
|
||||
else FactorizedReduce(C_in, C_out, stride, affine, track_running_stats),
|
||||
}
|
||||
|
||||
CONNECT_NAS_BENCHMARK = ["none", "skip_connect", "nor_conv_3x3"]
|
||||
NAS_BENCH_201 = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]
|
||||
DARTS_SPACE = [
|
||||
"none",
|
||||
"skip_connect",
|
||||
"dua_sepc_3x3",
|
||||
"dua_sepc_5x5",
|
||||
"dil_sepc_3x3",
|
||||
"dil_sepc_5x5",
|
||||
"avg_pool_3x3",
|
||||
"max_pool_3x3",
|
||||
]
|
||||
|
||||
SearchSpaceNames = {
|
||||
"connect-nas": CONNECT_NAS_BENCHMARK,
|
||||
"nats-bench": NAS_BENCH_201,
|
||||
"nas-bench-201": NAS_BENCH_201,
|
||||
"darts": DARTS_SPACE,
|
||||
}
|
||||
|
||||
|
||||
class ReLUConvBN(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C_in,
|
||||
C_out,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
affine,
|
||||
track_running_stats=True,
|
||||
):
|
||||
super(ReLUConvBN, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C_in,
|
||||
C_out,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=not affine,
|
||||
),
|
||||
nn.BatchNorm2d(
|
||||
C_out, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class SepConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C_in,
|
||||
C_out,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
affine,
|
||||
track_running_stats=True,
|
||||
):
|
||||
super(SepConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C_in,
|
||||
C_in,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=C_in,
|
||||
bias=False,
|
||||
),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=not affine),
|
||||
nn.BatchNorm2d(
|
||||
C_out, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class DualSepConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C_in,
|
||||
C_out,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
affine,
|
||||
track_running_stats=True,
|
||||
):
|
||||
super(DualSepConv, self).__init__()
|
||||
self.op_a = SepConv(
|
||||
C_in,
|
||||
C_in,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
self.op_b = SepConv(
|
||||
C_in, C_out, kernel_size, 1, padding, dilation, affine, track_running_stats
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.op_a(x)
|
||||
x = self.op_b(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
def __init__(self, inplanes, planes, stride, affine=True, track_running_stats=True):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_a = ReLUConvBN(
|
||||
inplanes, planes, 3, stride, 1, 1, affine, track_running_stats
|
||||
)
|
||||
self.conv_b = ReLUConvBN(
|
||||
planes, planes, 3, 1, 1, 1, affine, track_running_stats
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
|
||||
nn.Conv2d(
|
||||
inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False
|
||||
),
|
||||
)
|
||||
elif inplanes != planes:
|
||||
self.downsample = ReLUConvBN(
|
||||
inplanes, planes, 1, 1, 0, 1, affine, track_running_stats
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.in_dim = inplanes
|
||||
self.out_dim = planes
|
||||
self.stride = stride
|
||||
self.num_conv = 2
|
||||
|
||||
def extra_repr(self):
|
||||
string = "{name}(inC={in_dim}, outC={out_dim}, stride={stride})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
return string
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
return residual + basicblock
|
||||
|
||||
|
||||
class POOLING(nn.Module):
|
||||
def __init__(
|
||||
self, C_in, C_out, stride, mode, affine=True, track_running_stats=True
|
||||
):
|
||||
super(POOLING, self).__init__()
|
||||
if C_in == C_out:
|
||||
self.preprocess = None
|
||||
else:
|
||||
self.preprocess = ReLUConvBN(
|
||||
C_in, C_out, 1, 1, 0, 1, affine, track_running_stats
|
||||
)
|
||||
if mode == "avg":
|
||||
self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
|
||||
elif mode == "max":
|
||||
self.op = nn.MaxPool2d(3, stride=stride, padding=1)
|
||||
else:
|
||||
raise ValueError("Invalid mode={:} in POOLING".format(mode))
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.preprocess:
|
||||
x = self.preprocess(inputs)
|
||||
else:
|
||||
x = inputs
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class Zero(nn.Module):
|
||||
def __init__(self, C_in, C_out, stride):
|
||||
super(Zero, self).__init__()
|
||||
self.C_in = C_in
|
||||
self.C_out = C_out
|
||||
self.stride = stride
|
||||
self.is_zero = True
|
||||
|
||||
def forward(self, x):
|
||||
if self.C_in == self.C_out:
|
||||
if self.stride == 1:
|
||||
return x.mul(0.0)
|
||||
else:
|
||||
return x[:, :, :: self.stride, :: self.stride].mul(0.0)
|
||||
else:
|
||||
shape = list(x.shape)
|
||||
shape[1] = self.C_out
|
||||
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
|
||||
return zeros
|
||||
|
||||
def extra_repr(self):
|
||||
return "C_in={C_in}, C_out={C_out}, stride={stride}".format(**self.__dict__)
|
||||
|
||||
|
||||
class FactorizedReduce(nn.Module):
|
||||
def __init__(self, C_in, C_out, stride, affine, track_running_stats):
|
||||
super(FactorizedReduce, self).__init__()
|
||||
self.stride = stride
|
||||
self.C_in = C_in
|
||||
self.C_out = C_out
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
if stride == 2:
|
||||
# assert C_out % 2 == 0, 'C_out : {:}'.format(C_out)
|
||||
C_outs = [C_out // 2, C_out - C_out // 2]
|
||||
self.convs = nn.ModuleList()
|
||||
for i in range(2):
|
||||
self.convs.append(
|
||||
nn.Conv2d(
|
||||
C_in, C_outs[i], 1, stride=stride, padding=0, bias=not affine
|
||||
)
|
||||
)
|
||||
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
|
||||
elif stride == 1:
|
||||
self.conv = nn.Conv2d(
|
||||
C_in, C_out, 1, stride=stride, padding=0, bias=not affine
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid stride : {:}".format(stride))
|
||||
self.bn = nn.BatchNorm2d(
|
||||
C_out, affine=affine, track_running_stats=track_running_stats
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.stride == 2:
|
||||
x = self.relu(x)
|
||||
y = self.pad(x)
|
||||
out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
|
||||
else:
|
||||
out = self.conv(x)
|
||||
out = self.bn(out)
|
||||
return out
|
||||
|
||||
def extra_repr(self):
|
||||
return "C_in={C_in}, C_out={C_out}, stride={stride}".format(**self.__dict__)
|
||||
|
||||
|
||||
# Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019
|
||||
class PartAwareOp(nn.Module):
|
||||
def __init__(self, C_in, C_out, stride, part=4):
|
||||
super().__init__()
|
||||
self.part = part
|
||||
self.hidden = C_in // 3
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.local_conv_list = nn.ModuleList()
|
||||
for i in range(self.part):
|
||||
self.local_conv_list.append(
|
||||
nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(C_in, self.hidden, 1),
|
||||
nn.BatchNorm2d(self.hidden, affine=True),
|
||||
)
|
||||
)
|
||||
self.W_K = nn.Linear(self.hidden, self.hidden)
|
||||
self.W_Q = nn.Linear(self.hidden, self.hidden)
|
||||
|
||||
if stride == 2:
|
||||
self.last = FactorizedReduce(C_in + self.hidden, C_out, 2)
|
||||
elif stride == 1:
|
||||
self.last = FactorizedReduce(C_in + self.hidden, C_out, 1)
|
||||
else:
|
||||
raise ValueError("Invalid Stride : {:}".format(stride))
|
||||
|
||||
def forward(self, x):
|
||||
batch, C, H, W = x.size()
|
||||
assert H >= self.part, "input size too small : {:} vs {:}".format(
|
||||
x.shape, self.part
|
||||
)
|
||||
IHs = [0]
|
||||
for i in range(self.part):
|
||||
IHs.append(min(H, int((i + 1) * (float(H) / self.part))))
|
||||
local_feat_list = []
|
||||
for i in range(self.part):
|
||||
feature = x[:, :, IHs[i] : IHs[i + 1], :]
|
||||
xfeax = self.avg_pool(feature)
|
||||
xfea = self.local_conv_list[i](xfeax)
|
||||
local_feat_list.append(xfea)
|
||||
part_feature = torch.cat(local_feat_list, dim=2).view(batch, -1, self.part)
|
||||
part_feature = part_feature.transpose(1, 2).contiguous()
|
||||
part_K = self.W_K(part_feature)
|
||||
part_Q = self.W_Q(part_feature).transpose(1, 2).contiguous()
|
||||
weight_att = torch.bmm(part_K, part_Q)
|
||||
attention = torch.softmax(weight_att, dim=2)
|
||||
aggreateF = torch.bmm(attention, part_feature).transpose(1, 2).contiguous()
|
||||
features = []
|
||||
for i in range(self.part):
|
||||
feature = aggreateF[:, :, i : i + 1].expand(
|
||||
batch, self.hidden, IHs[i + 1] - IHs[i]
|
||||
)
|
||||
feature = feature.view(batch, self.hidden, IHs[i + 1] - IHs[i], 1)
|
||||
features.append(feature)
|
||||
features = torch.cat(features, dim=2).expand(batch, self.hidden, H, W)
|
||||
final_fea = torch.cat((x, features), dim=1)
|
||||
outputs = self.last(final_fea)
|
||||
return outputs
|
||||
|
||||
|
||||
def drop_path(x, drop_prob):
|
||||
if drop_prob > 0.0:
|
||||
keep_prob = 1.0 - drop_prob
|
||||
mask = x.new_zeros(x.size(0), 1, 1, 1)
|
||||
mask = mask.bernoulli_(keep_prob)
|
||||
x = torch.div(x, keep_prob)
|
||||
x.mul_(mask)
|
||||
return x
|
||||
|
||||
|
||||
# Searching for A Robust Neural Architecture in Four GPU Hours
|
||||
class GDAS_Reduction_Cell(nn.Module):
|
||||
def __init__(
|
||||
self, C_prev_prev, C_prev, C, reduction_prev, affine, track_running_stats
|
||||
):
|
||||
super(GDAS_Reduction_Cell, self).__init__()
|
||||
if reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(
|
||||
C_prev_prev, C, 2, affine, track_running_stats
|
||||
)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(
|
||||
C_prev_prev, C, 1, 1, 0, 1, affine, track_running_stats
|
||||
)
|
||||
self.preprocess1 = ReLUConvBN(
|
||||
C_prev, C, 1, 1, 0, 1, affine, track_running_stats
|
||||
)
|
||||
|
||||
self.reduction = True
|
||||
self.ops1 = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C,
|
||||
C,
|
||||
(1, 3),
|
||||
stride=(1, 2),
|
||||
padding=(0, 1),
|
||||
groups=8,
|
||||
bias=not affine,
|
||||
),
|
||||
nn.Conv2d(
|
||||
C,
|
||||
C,
|
||||
(3, 1),
|
||||
stride=(2, 1),
|
||||
padding=(1, 0),
|
||||
groups=8,
|
||||
bias=not affine,
|
||||
),
|
||||
nn.BatchNorm2d(
|
||||
C, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine),
|
||||
nn.BatchNorm2d(
|
||||
C, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
),
|
||||
nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(
|
||||
C,
|
||||
C,
|
||||
(1, 3),
|
||||
stride=(1, 2),
|
||||
padding=(0, 1),
|
||||
groups=8,
|
||||
bias=not affine,
|
||||
),
|
||||
nn.Conv2d(
|
||||
C,
|
||||
C,
|
||||
(3, 1),
|
||||
stride=(2, 1),
|
||||
padding=(1, 0),
|
||||
groups=8,
|
||||
bias=not affine,
|
||||
),
|
||||
nn.BatchNorm2d(
|
||||
C, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C, C, 1, stride=1, padding=0, bias=not affine),
|
||||
nn.BatchNorm2d(
|
||||
C, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.ops2 = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.MaxPool2d(3, stride=2, padding=1),
|
||||
nn.BatchNorm2d(
|
||||
C, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
),
|
||||
nn.Sequential(
|
||||
nn.MaxPool2d(3, stride=2, padding=1),
|
||||
nn.BatchNorm2d(
|
||||
C, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def multiplier(self):
|
||||
return 4
|
||||
|
||||
def forward(self, s0, s1, drop_prob=-1):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
X0 = self.ops1[0](s0)
|
||||
X1 = self.ops1[1](s1)
|
||||
if self.training and drop_prob > 0.0:
|
||||
X0, X1 = drop_path(X0, drop_prob), drop_path(X1, drop_prob)
|
||||
|
||||
# X2 = self.ops2[0] (X0+X1)
|
||||
X2 = self.ops2[0](s0)
|
||||
X3 = self.ops2[1](s1)
|
||||
if self.training and drop_prob > 0.0:
|
||||
X2, X3 = drop_path(X2, drop_prob), drop_path(X3, drop_prob)
|
||||
return torch.cat([X0, X1, X2, X3], dim=1)
|
||||
|
||||
|
||||
# To manage the useful classes in this file.
|
||||
RAW_OP_CLASSES = {"gdas_reduction": GDAS_Reduction_Cell}
|
||||
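As a quick reference (an assumed example, not part of the file), every OPS entry is a factory with the signature (C_in, C_out, stride, affine, track_running_stats) that returns an nn.Module:

import torch

op = OPS["nor_conv_3x3"](16, 16, 1, True, True)
x = torch.randn(2, 16, 32, 32)
print(op(x).shape)          # torch.Size([2, 16, 32, 32]) -- padding keeps the spatial size

reduce_op = OPS["skip_connect"](16, 32, 2, True, True)   # C_in != C_out, so FactorizedReduce is used
print(reduce_op(x).shape)   # torch.Size([2, 32, 16, 16])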
33
AutoDL-Projects/xautodl/models/cell_searchs/__init__.py
Normal file
@@ -0,0 +1,33 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
# The macro structure is defined in NAS-Bench-201
from .search_model_darts import TinyNetworkDarts
from .search_model_gdas import TinyNetworkGDAS
from .search_model_setn import TinyNetworkSETN
from .search_model_enas import TinyNetworkENAS
from .search_model_random import TinyNetworkRANDOM
from .generic_model import GenericNAS201Model
from .genotypes import Structure as CellStructure, architectures as CellArchitectures

# NASNet-based macro structure
from .search_model_gdas_nasnet import NASNetworkGDAS
from .search_model_gdas_frc_nasnet import NASNetworkGDAS_FRC
from .search_model_darts_nasnet import NASNetworkDARTS


nas201_super_nets = {
    "DARTS-V1": TinyNetworkDarts,
    "DARTS-V2": TinyNetworkDarts,
    "GDAS": TinyNetworkGDAS,
    "SETN": TinyNetworkSETN,
    "ENAS": TinyNetworkENAS,
    "RANDOM": TinyNetworkRANDOM,
    "generic": GenericNAS201Model,
}

nasnet_super_nets = {
    "GDAS": NASNetworkGDAS,
    "GDAS_FRC": NASNetworkGDAS_FRC,
    "DARTS": NASNetworkDARTS,
}
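For orientation, an assumed example of how these registries are consumed; the positional arguments mirror the fallback call inside get_cell_based_tiny_net in models/__init__.py.

from xautodl.models.cell_operations import SearchSpaceNames

space = SearchSpaceNames["nas-bench-201"]   # the five candidate operations
net_cls = nas201_super_nets["GDAS"]         # TinyNetworkGDAS
net = net_cls(16, 5, 4, 10, space)          # C, N, max_nodes, num_classes, search space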
14
AutoDL-Projects/xautodl/models/cell_searchs/_test_module.py
Normal file
@@ -0,0 +1,14 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from search_model_enas_utils import Controller


def main():
    controller = Controller(6, 4)
    predictions = controller()


if __name__ == "__main__":
    main()
366
AutoDL-Projects/xautodl/models/cell_searchs/generic_model.py
Normal file
@@ -0,0 +1,366 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.07 #
|
||||
#####################################################
|
||||
import torch, random
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from typing import Text
|
||||
from torch.distributions.categorical import Categorical
|
||||
|
||||
from ..cell_operations import ResNetBasicblock, drop_path
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
class Controller(nn.Module):
|
||||
# we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
|
||||
def __init__(
|
||||
self,
|
||||
edge2index,
|
||||
op_names,
|
||||
max_nodes,
|
||||
lstm_size=32,
|
||||
lstm_num_layers=2,
|
||||
tanh_constant=2.5,
|
||||
temperature=5.0,
|
||||
):
|
||||
super(Controller, self).__init__()
|
||||
# assign the attributes
|
||||
self.max_nodes = max_nodes
|
||||
self.num_edge = len(edge2index)
|
||||
self.edge2index = edge2index
|
||||
self.num_ops = len(op_names)
|
||||
self.op_names = op_names
|
||||
self.lstm_size = lstm_size
|
||||
self.lstm_N = lstm_num_layers
|
||||
self.tanh_constant = tanh_constant
|
||||
self.temperature = temperature
|
||||
# create parameters
|
||||
self.register_parameter(
|
||||
"input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size))
|
||||
)
|
||||
self.w_lstm = nn.LSTM(
|
||||
input_size=self.lstm_size,
|
||||
hidden_size=self.lstm_size,
|
||||
num_layers=self.lstm_N,
|
||||
)
|
||||
self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
|
||||
self.w_pred = nn.Linear(self.lstm_size, self.num_ops)
|
||||
|
||||
nn.init.uniform_(self.input_vars, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_embd.weight, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_pred.weight, -0.1, 0.1)
|
||||
|
||||
def convert_structure(self, _arch):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op_index = _arch[self.edge2index[node_str]]
|
||||
op_name = self.op_names[op_index]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def forward(self):
|
||||
|
||||
inputs, h0 = self.input_vars, None
|
||||
log_probs, entropys, sampled_arch = [], [], []
|
||||
for iedge in range(self.num_edge):
|
||||
outputs, h0 = self.w_lstm(inputs, h0)
|
||||
|
||||
logits = self.w_pred(outputs)
|
||||
logits = logits / self.temperature
|
||||
logits = self.tanh_constant * torch.tanh(logits)
|
||||
# distribution
|
||||
op_distribution = Categorical(logits=logits)
|
||||
op_index = op_distribution.sample()
|
||||
sampled_arch.append(op_index.item())
|
||||
|
||||
op_log_prob = op_distribution.log_prob(op_index)
|
||||
log_probs.append(op_log_prob.view(-1))
|
||||
op_entropy = op_distribution.entropy()
|
||||
entropys.append(op_entropy.view(-1))
|
||||
|
||||
# obtain the input embedding for the next step
|
||||
inputs = self.w_embd(op_index)
|
||||
return (
|
||||
torch.sum(torch.cat(log_probs)),
|
||||
torch.sum(torch.cat(entropys)),
|
||||
self.convert_structure(sampled_arch),
|
||||
)
|
||||
|
||||
|
||||
class GenericNAS201Model(nn.Module):
|
||||
def __init__(
|
||||
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
|
||||
):
|
||||
super(GenericNAS201Model, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self._max_nodes = max_nodes
|
||||
self._stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
C_prev, num_edge, edge2index = C, None, None
|
||||
self._cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
C_prev,
|
||||
C_curr,
|
||||
1,
|
||||
max_nodes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self._cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self._op_names = deepcopy(search_space)
|
||||
self._Layer = len(self._cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(
|
||||
nn.BatchNorm2d(
|
||||
C_prev, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self._num_edge = num_edge
|
||||
# algorithm related
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self._mode = None
|
||||
self.dynamic_cell = None
|
||||
self._tau = None
|
||||
self._algo = None
|
||||
self._drop_path = None
|
||||
self.verbose = False
|
||||
|
||||
def set_algo(self, algo: Text):
|
||||
# used for searching
|
||||
assert self._algo is None, "This function can only be called once."
|
||||
self._algo = algo
|
||||
if algo == "enas":
|
||||
self.controller = Controller(
|
||||
self.edge2index, self._op_names, self._max_nodes
|
||||
)
|
||||
else:
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(self._num_edge, len(self._op_names))
|
||||
)
|
||||
if algo == "gdas":
|
||||
self._tau = 10
|
||||
|
||||
def set_cal_mode(self, mode, dynamic_cell=None):
|
||||
assert mode in ["gdas", "enas", "urs", "joint", "select", "dynamic"]
|
||||
self._mode = mode
|
||||
if mode == "dynamic":
|
||||
self.dynamic_cell = deepcopy(dynamic_cell)
|
||||
else:
|
||||
self.dynamic_cell = None
|
||||
|
||||
def set_drop_path(self, progress, drop_path_rate):
|
||||
if drop_path_rate is None:
|
||||
self._drop_path = None
|
||||
elif progress is None:
|
||||
self._drop_path = drop_path_rate
|
||||
else:
|
||||
self._drop_path = progress * drop_path_rate
|
||||
|
||||
@property
|
||||
def mode(self):
|
||||
return self._mode
|
||||
|
||||
@property
|
||||
def drop_path(self):
|
||||
return self._drop_path
|
||||
|
||||
@property
|
||||
def weights(self):
|
||||
xlist = list(self._stem.parameters())
|
||||
xlist += list(self._cells.parameters())
|
||||
xlist += list(self.lastact.parameters())
|
||||
xlist += list(self.global_pooling.parameters())
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def set_tau(self, tau):
|
||||
self._tau = tau
|
||||
|
||||
@property
|
||||
def tau(self):
|
||||
return self._tau
|
||||
|
||||
@property
|
||||
def alphas(self):
|
||||
if self._algo == "enas":
|
||||
return list(self.controller.parameters())
|
||||
else:
|
||||
return [self.arch_parameters]
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self._cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self._cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
if self._algo == "enas":
|
||||
return "w_pred :\n{:}".format(self.controller.w_pred.weight)
|
||||
else:
|
||||
return "arch-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
|
||||
)
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, Max-Nodes={_max_nodes}, N={_layerN}, L={_Layer}, alg={_algo})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
@property
|
||||
def genotype(self):
|
||||
genotypes = []
|
||||
for i in range(1, self._max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
with torch.no_grad():
|
||||
weights = self.arch_parameters[self.edge2index[node_str]]
|
||||
op_name = self._op_names[weights.argmax().item()]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def dync_genotype(self, use_random=False):
|
||||
genotypes = []
|
||||
with torch.no_grad():
|
||||
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
for i in range(1, self._max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
if use_random:
|
||||
op_name = random.choice(self._op_names)
|
||||
else:
|
||||
weights = alphas_cpu[self.edge2index[node_str]]
|
||||
op_index = torch.multinomial(weights, 1).item()
|
||||
op_name = self._op_names[op_index]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def get_log_prob(self, arch):
|
||||
with torch.no_grad():
|
||||
logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
|
||||
select_logits = []
|
||||
for i, node_info in enumerate(arch.nodes):
|
||||
for op, xin in node_info:
|
||||
node_str = "{:}<-{:}".format(i + 1, xin)
|
||||
op_index = self._op_names.index(op)
|
||||
select_logits.append(logits[self.edge2index[node_str], op_index])
|
||||
return sum(select_logits).item()
|
||||
|
||||
def return_topK(self, K, use_random=False):
|
||||
archs = Structure.gen_all(self._op_names, self._max_nodes, False)
|
||||
pairs = [(self.get_log_prob(arch), arch) for arch in archs]
|
||||
if K < 0 or K >= len(archs):
|
||||
K = len(archs)
|
||||
if use_random:
|
||||
return random.sample(archs, K)
|
||||
else:
|
||||
sorted_pairs = sorted(pairs, key=lambda x: -x[0])
|
||||
return_pairs = [sorted_pairs[_][1] for _ in range(K)]
|
||||
return return_pairs
|
||||
|
||||
def normalize_archp(self):
|
||||
if self.mode == "gdas":
|
||||
while True:
|
||||
gumbels = -torch.empty_like(self.arch_parameters).exponential_().log()
|
||||
logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau
|
||||
probs = nn.functional.softmax(logits, dim=1)
|
||||
index = probs.max(-1, keepdim=True)[1]
|
||||
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
|
||||
hardwts = one_h - probs.detach() + probs
|
||||
if (
|
||||
(torch.isinf(gumbels).any())
|
||||
or (torch.isinf(probs).any())
|
||||
or (torch.isnan(probs).any())
|
||||
):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
with torch.no_grad():
|
||||
hardwts_cpu = hardwts.detach().cpu()
|
||||
return hardwts, hardwts_cpu, index, "GUMBEL"
|
||||
else:
|
||||
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
index = alphas.max(-1, keepdim=True)[1]
|
||||
with torch.no_grad():
|
||||
alphas_cpu = alphas.detach().cpu()
|
||||
return alphas, alphas_cpu, index, "SOFTMAX"
|
||||
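# Note on the "gdas" branch above: it implements the straight-through Gumbel-softmax trick.
# "hardwts = one_h - probs.detach() + probs" equals the one-hot vector in the forward pass,
# while its gradient w.r.t. arch_parameters is that of the soft probabilities.
# A minimal illustration (hypothetical values, not taken from this repository):
#   probs   = [0.2, 0.7, 0.1]  ->  index = 1
#   one_h   = [0.0, 1.0, 0.0]
#   hardwts = [0.0, 1.0, 0.0] in value, but d(hardwts)/d(alpha) == d(probs)/d(alpha)
# The while-loop simply resamples whenever the Gumbel noise yields inf/nan values.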
|
||||
def forward(self, inputs):
|
||||
alphas, alphas_cpu, index, verbose_str = self.normalize_archp()
|
||||
feature = self._stem(inputs)
|
||||
for i, cell in enumerate(self._cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
if self.mode == "urs":
|
||||
feature = cell.forward_urs(feature)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_urs"
|
||||
elif self.mode == "select":
|
||||
feature = cell.forward_select(feature, alphas_cpu)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_select"
|
||||
elif self.mode == "joint":
|
||||
feature = cell.forward_joint(feature, alphas)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_joint"
|
||||
elif self.mode == "dynamic":
|
||||
feature = cell.forward_dynamic(feature, self.dynamic_cell)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_dynamic"
|
||||
elif self.mode == "gdas":
|
||||
feature = cell.forward_gdas(feature, alphas, index)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_gdas"
|
||||
elif self.mode == "gdas_v1":
|
||||
feature = cell.forward_gdas_v1(feature, alphas, index)
|
||||
if self.verbose:
|
||||
verbose_str += "-forward_gdas_v1"
|
||||
else:
|
||||
raise ValueError("invalid mode={:}".format(self.mode))
|
||||
else:
|
||||
feature = cell(feature)
|
||||
if self.drop_path is not None:
|
||||
feature = drop_path(feature, self.drop_path)
|
||||
if self.verbose and random.random() < 0.001:
|
||||
print(verbose_str)
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
return out, logits
|
||||
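# A rough usage sketch for the searchable model above (illustrative only; the constructor
# signature is defined earlier in this file):
#   model.set_algo("gdas")       # choose the search algorithm once
#   model.set_cal_mode("gdas")   # choose how SearchCells are evaluated in forward()
#   model.set_tau(10)            # Gumbel temperature, typically annealed during search
#   out, logits = model(images)  # forward returns (pooled features, class logits)
#   arch = model.genotype        # argmax-per-edge architecture as a Structure object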
274
AutoDL-Projects/xautodl/models/cell_searchs/genotypes.py
Normal file
@@ -0,0 +1,274 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
def get_combination(space, num):
|
||||
combs = []
|
||||
for i in range(num):
|
||||
if i == 0:
|
||||
for func in space:
|
||||
combs.append([(func, i)])
|
||||
else:
|
||||
new_combs = []
|
||||
for string in combs:
|
||||
for func in space:
|
||||
xstring = string + [(func, i)]
|
||||
new_combs.append(xstring)
|
||||
combs = new_combs
|
||||
return combs
|
||||
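# For example (illustrative), get_combination(["A", "B"], 2) enumerates every way to assign
# one operation to each of the two incoming edges of a node:
#   [[("A", 0), ("A", 1)], [("A", 0), ("B", 1)], [("B", 0), ("A", 1)], [("B", 0), ("B", 1)]]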
|
||||
|
||||
class Structure:
|
||||
def __init__(self, genotype):
|
||||
assert isinstance(genotype, list) or isinstance(
|
||||
genotype, tuple
|
||||
), "invalid class of genotype : {:}".format(type(genotype))
|
||||
self.node_num = len(genotype) + 1
|
||||
self.nodes = []
|
||||
self.node_N = []
|
||||
for idx, node_info in enumerate(genotype):
|
||||
assert isinstance(node_info, list) or isinstance(
|
||||
node_info, tuple
|
||||
), "invalid class of node_info : {:}".format(type(node_info))
|
||||
assert len(node_info) >= 1, "invalid length : {:}".format(len(node_info))
|
||||
for node_in in node_info:
|
||||
assert isinstance(node_in, list) or isinstance(
|
||||
node_in, tuple
|
||||
), "invalid class of in-node : {:}".format(type(node_in))
|
||||
assert (
|
||||
len(node_in) == 2 and node_in[1] <= idx
|
||||
), "invalid in-node : {:}".format(node_in)
|
||||
self.node_N.append(len(node_info))
|
||||
self.nodes.append(tuple(deepcopy(node_info)))
|
||||
|
||||
def tolist(self, remove_str):
|
||||
# Convert this class into a list; if remove_str is 'none', the 'none' operations are removed.
|
||||
# Note that the input nodes are re-ordered in this function.
|
||||
# Returns the genotype list and a success flag [if False, the genotype is not a valid connectivity].
|
||||
genotypes = []
|
||||
for node_info in self.nodes:
|
||||
node_info = list(node_info)
|
||||
node_info = sorted(node_info, key=lambda x: (x[1], x[0]))
|
||||
node_info = tuple(filter(lambda x: x[0] != remove_str, node_info))
|
||||
if len(node_info) == 0:
|
||||
return None, False
|
||||
genotypes.append(node_info)
|
||||
return genotypes, True
|
||||
|
||||
def node(self, index):
|
||||
assert index > 0 and index <= len(self), "invalid index={:} < {:}".format(
|
||||
index, len(self)
|
||||
)
|
||||
return self.nodes[index]
|
||||
|
||||
def tostr(self):
|
||||
strings = []
|
||||
for node_info in self.nodes:
|
||||
string = "|".join([x[0] + "~{:}".format(x[1]) for x in node_info])
|
||||
string = "|{:}|".format(string)
|
||||
strings.append(string)
|
||||
return "+".join(strings)
|
||||
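# For instance, a 4-node NAS-Bench-201-style cell serializes to a string such as (illustrative)
# "|nor_conv_3x3~0|+|skip_connect~0|none~1|+|nor_conv_1x1~0|skip_connect~1|avg_pool_3x3~2|",
# where "op~j" denotes the edge from node-j into the current node.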
|
||||
def check_valid(self):
|
||||
nodes = {0: True}
|
||||
for i, node_info in enumerate(self.nodes):
|
||||
sums = []
|
||||
for op, xin in node_info:
|
||||
if op == "none" or nodes[xin] is False:
|
||||
x = False
|
||||
else:
|
||||
x = True
|
||||
sums.append(x)
|
||||
nodes[i + 1] = sum(sums) > 0
|
||||
return nodes[len(self.nodes)]
|
||||
|
||||
def to_unique_str(self, consider_zero=False):
|
||||
# This is used to identify isomorphic cells, which requires prior knowledge of the operations.
|
||||
# two operations are special, i.e., none and skip_connect
|
||||
nodes = {0: "0"}
|
||||
for i_node, node_info in enumerate(self.nodes):
|
||||
cur_node = []
|
||||
for op, xin in node_info:
|
||||
if consider_zero is None:
|
||||
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
|
||||
elif consider_zero:
|
||||
if op == "none" or nodes[xin] == "#":
|
||||
x = "#" # zero
|
||||
elif op == "skip_connect":
|
||||
x = nodes[xin]
|
||||
else:
|
||||
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
|
||||
else:
|
||||
if op == "skip_connect":
|
||||
x = nodes[xin]
|
||||
else:
|
||||
x = "(" + nodes[xin] + ")" + "@{:}".format(op)
|
||||
cur_node.append(x)
|
||||
nodes[i_node + 1] = "+".join(sorted(cur_node))
|
||||
return nodes[len(self.nodes)]
|
||||
|
||||
def check_valid_op(self, op_names):
|
||||
for node_info in self.nodes:
|
||||
for inode_edge in node_info:
|
||||
# assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0])
|
||||
if inode_edge[0] not in op_names:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({node_num} nodes with {node_info})".format(
|
||||
name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.nodes) + 1
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.nodes[index]
|
||||
|
||||
@staticmethod
|
||||
def str2structure(xstr):
|
||||
if isinstance(xstr, Structure):
|
||||
return xstr
|
||||
assert isinstance(xstr, str), "must take string (not {:}) as input".format(
|
||||
type(xstr)
|
||||
)
|
||||
nodestrs = xstr.split("+")
|
||||
genotypes = []
|
||||
for i, node_str in enumerate(nodestrs):
|
||||
inputs = list(filter(lambda x: x != "", node_str.split("|")))
|
||||
for xinput in inputs:
|
||||
assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
|
||||
xinput
|
||||
)
|
||||
inputs = (xi.split("~") for xi in inputs)
|
||||
input_infos = tuple((op, int(IDX)) for (op, IDX) in inputs)
|
||||
genotypes.append(input_infos)
|
||||
return Structure(genotypes)
|
||||
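# str2structure is the inverse of tostr for strings produced by it, so (illustratively)
#   Structure.str2structure(arch.tostr()).tostr() == arch.tostr()
# should hold for any valid Structure instance `arch`.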
|
||||
@staticmethod
|
||||
def str2fullstructure(xstr, default_name="none"):
|
||||
assert isinstance(xstr, str), "must take string (not {:}) as input".format(
|
||||
type(xstr)
|
||||
)
|
||||
nodestrs = xstr.split("+")
|
||||
genotypes = []
|
||||
for i, node_str in enumerate(nodestrs):
|
||||
inputs = list(filter(lambda x: x != "", node_str.split("|")))
|
||||
for xinput in inputs:
|
||||
assert len(xinput.split("~")) == 2, "invalid input length : {:}".format(
|
||||
xinput
|
||||
)
|
||||
inputs = (xi.split("~") for xi in inputs)
|
||||
input_infos = list((op, int(IDX)) for (op, IDX) in inputs)
|
||||
all_in_nodes = list(x[1] for x in input_infos)
|
||||
for j in range(i):
|
||||
if j not in all_in_nodes:
|
||||
input_infos.append((default_name, j))
|
||||
node_info = sorted(input_infos, key=lambda x: (x[1], x[0]))
|
||||
genotypes.append(tuple(node_info))
|
||||
return Structure(genotypes)
|
||||
|
||||
@staticmethod
|
||||
def gen_all(search_space, num, return_ori):
|
||||
assert isinstance(search_space, list) or isinstance(
|
||||
search_space, tuple
|
||||
), "invalid class of search-space : {:}".format(type(search_space))
|
||||
assert (
|
||||
num >= 2
|
||||
), "There should be at least two nodes in a neural cell instead of {:}".format(
|
||||
num
|
||||
)
|
||||
all_archs = get_combination(search_space, 1)
|
||||
for i, arch in enumerate(all_archs):
|
||||
all_archs[i] = [tuple(arch)]
|
||||
|
||||
for inode in range(2, num):
|
||||
cur_nodes = get_combination(search_space, inode)
|
||||
new_all_archs = []
|
||||
for previous_arch in all_archs:
|
||||
for cur_node in cur_nodes:
|
||||
new_all_archs.append(previous_arch + [tuple(cur_node)])
|
||||
all_archs = new_all_archs
|
||||
if return_ori:
|
||||
return all_archs
|
||||
else:
|
||||
return [Structure(x) for x in all_archs]
|
||||
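# gen_all enumerates the full cell space: with |O| candidate operations and `num` nodes there
# are num*(num-1)/2 edges, hence |O| ** (num*(num-1)/2) architectures in total -- for example,
# 5 operations and 4 nodes give 5 ** 6 = 15625 cells (the NAS-Bench-201 search space).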
|
||||
|
||||
ResNet_CODE = Structure(
|
||||
[
|
||||
(("nor_conv_3x3", 0),), # node-1
|
||||
(("nor_conv_3x3", 1),), # node-2
|
||||
(("skip_connect", 0), ("skip_connect", 2)),
|
||||
] # node-3
|
||||
)
|
||||
|
||||
AllConv3x3_CODE = Structure(
|
||||
[
|
||||
(("nor_conv_3x3", 0),), # node-1
|
||||
(("nor_conv_3x3", 0), ("nor_conv_3x3", 1)), # node-2
|
||||
(("nor_conv_3x3", 0), ("nor_conv_3x3", 1), ("nor_conv_3x3", 2)),
|
||||
] # node-3
|
||||
)
|
||||
|
||||
AllFull_CODE = Structure(
|
||||
[
|
||||
(
|
||||
("skip_connect", 0),
|
||||
("nor_conv_1x1", 0),
|
||||
("nor_conv_3x3", 0),
|
||||
("avg_pool_3x3", 0),
|
||||
), # node-1
|
||||
(
|
||||
("skip_connect", 0),
|
||||
("nor_conv_1x1", 0),
|
||||
("nor_conv_3x3", 0),
|
||||
("avg_pool_3x3", 0),
|
||||
("skip_connect", 1),
|
||||
("nor_conv_1x1", 1),
|
||||
("nor_conv_3x3", 1),
|
||||
("avg_pool_3x3", 1),
|
||||
), # node-2
|
||||
(
|
||||
("skip_connect", 0),
|
||||
("nor_conv_1x1", 0),
|
||||
("nor_conv_3x3", 0),
|
||||
("avg_pool_3x3", 0),
|
||||
("skip_connect", 1),
|
||||
("nor_conv_1x1", 1),
|
||||
("nor_conv_3x3", 1),
|
||||
("avg_pool_3x3", 1),
|
||||
("skip_connect", 2),
|
||||
("nor_conv_1x1", 2),
|
||||
("nor_conv_3x3", 2),
|
||||
("avg_pool_3x3", 2),
|
||||
),
|
||||
] # node-3
|
||||
)
|
||||
|
||||
AllConv1x1_CODE = Structure(
|
||||
[
|
||||
(("nor_conv_1x1", 0),), # node-1
|
||||
(("nor_conv_1x1", 0), ("nor_conv_1x1", 1)), # node-2
|
||||
(("nor_conv_1x1", 0), ("nor_conv_1x1", 1), ("nor_conv_1x1", 2)),
|
||||
] # node-3
|
||||
)
|
||||
|
||||
AllIdentity_CODE = Structure(
|
||||
[
|
||||
(("skip_connect", 0),), # node-1
|
||||
(("skip_connect", 0), ("skip_connect", 1)), # node-2
|
||||
(("skip_connect", 0), ("skip_connect", 1), ("skip_connect", 2)),
|
||||
] # node-3
|
||||
)
|
||||
|
||||
architectures = {
|
||||
"resnet": ResNet_CODE,
|
||||
"all_c3x3": AllConv3x3_CODE,
|
||||
"all_c1x1": AllConv1x1_CODE,
|
||||
"all_idnt": AllIdentity_CODE,
|
||||
"all_full": AllFull_CODE,
|
||||
}
|
||||
267
AutoDL-Projects/xautodl/models/cell_searchs/search_cells.py
Normal file
@@ -0,0 +1,267 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import math, random, torch
|
||||
import warnings
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import OPS
|
||||
|
||||
|
||||
# This module is used for NAS-Bench-201, represents a small search space with a complete DAG
|
||||
class NAS201SearchCell(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C_in,
|
||||
C_out,
|
||||
stride,
|
||||
max_nodes,
|
||||
op_names,
|
||||
affine=False,
|
||||
track_running_stats=True,
|
||||
):
|
||||
super(NAS201SearchCell, self).__init__()
|
||||
|
||||
self.op_names = deepcopy(op_names)
|
||||
self.edges = nn.ModuleDict()
|
||||
self.max_nodes = max_nodes
|
||||
self.in_dim = C_in
|
||||
self.out_dim = C_out
|
||||
for i in range(1, max_nodes):
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
if j == 0:
|
||||
xlists = [
|
||||
OPS[op_name](C_in, C_out, stride, affine, track_running_stats)
|
||||
for op_name in op_names
|
||||
]
|
||||
else:
|
||||
xlists = [
|
||||
OPS[op_name](C_in, C_out, 1, affine, track_running_stats)
|
||||
for op_name in op_names
|
||||
]
|
||||
self.edges[node_str] = nn.ModuleList(xlists)
|
||||
self.edge_keys = sorted(list(self.edges.keys()))
|
||||
self.edge2index = {key: i for i, key in enumerate(self.edge_keys)}
|
||||
self.num_edges = len(self.edges)
|
||||
|
||||
def extra_repr(self):
|
||||
string = "info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}".format(
|
||||
**self.__dict__
|
||||
)
|
||||
return string
|
||||
|
||||
def forward(self, inputs, weightss):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
weights = weightss[self.edge2index[node_str]]
|
||||
inter_nodes.append(
|
||||
sum(
|
||||
layer(nodes[j]) * w
|
||||
for layer, w in zip(self.edges[node_str], weights)
|
||||
)
|
||||
)
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
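# The plain forward above is the DARTS-style continuous relaxation: each intermediate node is
# node_i = sum_{j<i} sum_{o in O} w_{(i,j),o} * o(node_j), where the per-edge operation
# weights w are passed in as `weightss` (one row per edge, indexed through edge2index).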
|
||||
# GDAS
|
||||
def forward_gdas(self, inputs, hardwts, index):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
weights = hardwts[self.edge2index[node_str]]
|
||||
argmaxs = index[self.edge2index[node_str]].item()
|
||||
weigsum = sum(
|
||||
weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie]
|
||||
for _ie, edge in enumerate(self.edges[node_str])
|
||||
)
|
||||
inter_nodes.append(weigsum)
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
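# In forward_gdas only the argmax operation of each edge is actually computed; every other
# term contributes just the scalar weights[_ie], which is exactly 0 because hardwts is one-hot
# in value. Thanks to the straight-through estimator, gradients still flow into all entries
# of the architecture parameters even though a single operation is evaluated per edge.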
|
||||
# GDAS Variant: https://github.com/D-X-Y/AutoDL-Projects/issues/119
|
||||
def forward_gdas_v1(self, inputs, hardwts, index):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
weights = hardwts[self.edge2index[node_str]]
|
||||
argmaxs = index[self.edge2index[node_str]].item()
|
||||
weigsum = weights[argmaxs] * self.edges[node_str][argmaxs](nodes[j])
|
||||
inter_nodes.append(weigsum)
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
# joint
|
||||
def forward_joint(self, inputs, weightss):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
weights = weightss[self.edge2index[node_str]]
|
||||
# aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel()
|
||||
aggregation = sum(
|
||||
layer(nodes[j]) * w
|
||||
for layer, w in zip(self.edges[node_str], weights)
|
||||
)
|
||||
inter_nodes.append(aggregation)
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
# uniform random sampling per iteration, SETN
|
||||
def forward_urs(self, inputs):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
while True: # to avoid select zero for all ops
|
||||
sops, has_non_zero = [], False
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
candidates = self.edges[node_str]
|
||||
select_op = random.choice(candidates)
|
||||
sops.append(select_op)
|
||||
if not hasattr(select_op, "is_zero") or select_op.is_zero is False:
|
||||
has_non_zero = True
|
||||
if has_non_zero:
|
||||
break
|
||||
inter_nodes = []
|
||||
for j, select_op in enumerate(sops):
|
||||
inter_nodes.append(select_op(nodes[j]))
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
# select the argmax
|
||||
def forward_select(self, inputs, weightss):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
inter_nodes = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
weights = weightss[self.edge2index[node_str]]
|
||||
inter_nodes.append(
|
||||
self.edges[node_str][weights.argmax().item()](nodes[j])
|
||||
)
|
||||
# inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) )
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
# forward with a specific structure
|
||||
def forward_dynamic(self, inputs, structure):
|
||||
nodes = [inputs]
|
||||
for i in range(1, self.max_nodes):
|
||||
cur_op_node = structure.nodes[i - 1]
|
||||
inter_nodes = []
|
||||
for op_name, j in cur_op_node:
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op_index = self.op_names.index(op_name)
|
||||
inter_nodes.append(self.edges[node_str][op_index](nodes[j]))
|
||||
nodes.append(sum(inter_nodes))
|
||||
return nodes[-1]
|
||||
|
||||
|
||||
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
|
||||
|
||||
|
||||
class MixedOp(nn.Module):
|
||||
def __init__(self, space, C, stride, affine, track_running_stats):
|
||||
super(MixedOp, self).__init__()
|
||||
self._ops = nn.ModuleList()
|
||||
for primitive in space:
|
||||
op = OPS[primitive](C, C, stride, affine, track_running_stats)
|
||||
self._ops.append(op)
|
||||
|
||||
def forward_gdas(self, x, weights, index):
|
||||
return self._ops[index](x) * weights[index]
|
||||
|
||||
def forward_darts(self, x, weights):
|
||||
return sum(w * op(x) for w, op in zip(weights, self._ops))
|
||||
|
||||
|
||||
class NASNetSearchCell(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
space,
|
||||
steps,
|
||||
multiplier,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
):
|
||||
super(NASNetSearchCell, self).__init__()
|
||||
self.reduction = reduction
|
||||
self.op_names = deepcopy(space)
|
||||
if reduction_prev:
|
||||
self.preprocess0 = OPS["skip_connect"](
|
||||
C_prev_prev, C, 2, affine, track_running_stats
|
||||
)
|
||||
else:
|
||||
self.preprocess0 = OPS["nor_conv_1x1"](
|
||||
C_prev_prev, C, 1, affine, track_running_stats
|
||||
)
|
||||
self.preprocess1 = OPS["nor_conv_1x1"](
|
||||
C_prev, C, 1, affine, track_running_stats
|
||||
)
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
|
||||
self._ops = nn.ModuleList()
|
||||
self.edges = nn.ModuleDict()
|
||||
for i in range(self._steps):
|
||||
for j in range(2 + i):
|
||||
node_str = "{:}<-{:}".format(
|
||||
i, j
|
||||
) # indicate the edge from node-(j) to node-(i+2)
|
||||
stride = 2 if reduction and j < 2 else 1
|
||||
op = MixedOp(space, C, stride, affine, track_running_stats)
|
||||
self.edges[node_str] = op
|
||||
self.edge_keys = sorted(list(self.edges.keys()))
|
||||
self.edge2index = {key: i for i, key in enumerate(self.edge_keys)}
|
||||
self.num_edges = len(self.edges)
|
||||
|
||||
@property
|
||||
def multiplier(self):
|
||||
return self._multiplier
|
||||
|
||||
def forward_gdas(self, s0, s1, weightss, indexs):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i in range(self._steps):
|
||||
clist = []
|
||||
for j, h in enumerate(states):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op = self.edges[node_str]
|
||||
weights = weightss[self.edge2index[node_str]]
|
||||
index = indexs[self.edge2index[node_str]].item()
|
||||
clist.append(op.forward_gdas(h, weights, index))
|
||||
states.append(sum(clist))
|
||||
|
||||
return torch.cat(states[-self._multiplier :], dim=1)
|
||||
|
||||
def forward_darts(self, s0, s1, weightss):
|
||||
s0 = self.preprocess0(s0)
|
||||
s1 = self.preprocess1(s1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i in range(self._steps):
|
||||
clist = []
|
||||
for j, h in enumerate(states):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op = self.edges[node_str]
|
||||
weights = weightss[self.edge2index[node_str]]
|
||||
clist.append(op.forward_darts(h, weights))
|
||||
states.append(sum(clist))
|
||||
|
||||
return torch.cat(states[-self._multiplier :], dim=1)
|
||||
@@ -0,0 +1,122 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
########################################################
|
||||
# DARTS: Differentiable Architecture Search, ICLR 2019 #
|
||||
########################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
class TinyNetworkDarts(nn.Module):
|
||||
def __init__(
|
||||
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
|
||||
):
|
||||
super(TinyNetworkDarts, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self.max_nodes = max_nodes
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
|
||||
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
C_prev, num_edge, edge2index = C, None, None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
C_prev,
|
||||
C_curr,
|
||||
1,
|
||||
max_nodes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_parameters]
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
return "arch-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
|
||||
)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def genotype(self):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
with torch.no_grad():
|
||||
weights = self.arch_parameters[self.edge2index[node_str]]
|
||||
op_name = self.op_names[weights.argmax().item()]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def forward(self, inputs):
|
||||
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
feature = cell(feature, alphas)
|
||||
else:
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
||||
@@ -0,0 +1,178 @@
|
||||
####################
|
||||
# DARTS, ICLR 2019 #
|
||||
####################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from typing import List, Text, Dict
|
||||
from .search_cells import NASNetSearchCell as SearchCell
|
||||
|
||||
|
||||
# The macro structure is based on NASNet
|
||||
class NASNetworkDARTS(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C: int,
|
||||
N: int,
|
||||
steps: int,
|
||||
multiplier: int,
|
||||
stem_multiplier: int,
|
||||
num_classes: int,
|
||||
search_space: List[Text],
|
||||
affine: bool,
|
||||
track_running_stats: bool,
|
||||
):
|
||||
super(NASNetworkDARTS, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C * stem_multiplier),
|
||||
)
|
||||
|
||||
# config for each layer
|
||||
layer_channels = (
|
||||
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
|
||||
)
|
||||
layer_reductions = (
|
||||
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
|
||||
)
|
||||
|
||||
num_edge, edge2index = None, None
|
||||
C_prev_prev, C_prev, C_curr, reduction_prev = (
|
||||
C * stem_multiplier,
|
||||
C * stem_multiplier,
|
||||
C,
|
||||
False,
|
||||
)
|
||||
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
cell = SearchCell(
|
||||
search_space,
|
||||
steps,
|
||||
multiplier,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C_curr,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_normal_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.arch_reduce_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
|
||||
def get_weights(self) -> List[torch.nn.Parameter]:
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def get_alphas(self) -> List[torch.nn.Parameter]:
|
||||
return [self.arch_normal_parameters, self.arch_reduce_parameters]
|
||||
|
||||
def show_alphas(self) -> Text:
|
||||
with torch.no_grad():
|
||||
A = "arch-normal-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
|
||||
)
|
||||
B = "arch-reduce-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
|
||||
)
|
||||
return "{:}\n{:}".format(A, B)
|
||||
|
||||
def get_message(self) -> Text:
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self) -> Text:
|
||||
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def genotype(self) -> Dict[Text, List]:
|
||||
def _parse(weights):
|
||||
gene = []
|
||||
for i in range(self._steps):
|
||||
edges = []
|
||||
for j in range(2 + i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
ws = weights[self.edge2index[node_str]]
|
||||
for k, op_name in enumerate(self.op_names):
|
||||
if op_name == "none":
|
||||
continue
|
||||
edges.append((op_name, j, ws[k]))
|
||||
# (TODO) xuanyidong:
|
||||
# Here the selected two edges might come from the same input node.
|
||||
# And this case could be a problem, because the two edges would collapse into a single one
|
||||
# due to our assumption that at most one edge comes from each input node during evaluation.
|
||||
edges = sorted(edges, key=lambda x: -x[-1])
|
||||
selected_edges = edges[:2]
|
||||
gene.append(tuple(selected_edges))
|
||||
return gene
|
||||
|
||||
with torch.no_grad():
|
||||
gene_normal = _parse(
|
||||
torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
gene_reduce = _parse(
|
||||
torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
return {
|
||||
"normal": gene_normal,
|
||||
"normal_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
"reduce": gene_reduce,
|
||||
"reduce_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
}
|
||||
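# The returned dict mirrors the DARTS genotype convention: "normal" and "reduce" hold, for each
# intermediate step, the two highest-weighted (op_name, input_node, weight) edges, while the
# "*_concat" lists give the node indices whose outputs are concatenated as the cell output.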
|
||||
def forward(self, inputs):
|
||||
|
||||
normal_w = nn.functional.softmax(self.arch_normal_parameters, dim=1)
|
||||
reduce_w = nn.functional.softmax(self.arch_reduce_parameters, dim=1)
|
||||
|
||||
s0 = s1 = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if cell.reduction:
|
||||
ww = reduce_w
|
||||
else:
|
||||
ww = normal_w
|
||||
s0, s1 = s1, cell.forward_darts(s0, s1, ww)
|
||||
out = self.lastact(s1)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
||||
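# A hypothetical construction of this DARTS supernet (argument values are illustrative, not
# taken from the repository's configuration files):
#   net = NASNetworkDARTS(C=16, N=2, steps=4, multiplier=4, stem_multiplier=3,
#                         num_classes=10,
#                         search_space=["none", "skip_connect", "nor_conv_1x1",
#                                       "nor_conv_3x3", "avg_pool_3x3"],
#                         affine=False, track_running_stats=True)
# A weight optimizer then updates net.get_weights() while a separate optimizer updates
# net.get_alphas().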
114
AutoDL-Projects/xautodl/models/cell_searchs/search_model_enas.py
Normal file
@@ -0,0 +1,114 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##########################################################################
|
||||
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
|
||||
##########################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
from .search_model_enas_utils import Controller
|
||||
|
||||
|
||||
class TinyNetworkENAS(nn.Module):
|
||||
def __init__(
|
||||
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
|
||||
):
|
||||
super(TinyNetworkENAS, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self.max_nodes = max_nodes
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
|
||||
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
C_prev, num_edge, edge2index = C, None, None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
C_prev,
|
||||
C_curr,
|
||||
1,
|
||||
max_nodes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
# to maintain the sampled architecture
|
||||
self.sampled_arch = None
|
||||
|
||||
def update_arch(self, _arch):
|
||||
if _arch is None:
|
||||
self.sampled_arch = None
|
||||
elif isinstance(_arch, Structure):
|
||||
self.sampled_arch = _arch
|
||||
elif isinstance(_arch, (list, tuple)):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
op_index = _arch[self.edge2index[node_str]]
|
||||
op_name = self.op_names[op_index]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
self.sampled_arch = Structure(genotypes)
|
||||
else:
|
||||
raise ValueError("invalid type of input architecture : {:}".format(_arch))
|
||||
return self.sampled_arch
|
||||
|
||||
def create_controller(self):
|
||||
return Controller(len(self.edge2index), len(self.op_names))
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
feature = cell.forward_dynamic(feature, self.sampled_arch)
|
||||
else:
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
||||
@@ -0,0 +1,74 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##########################################################################
|
||||
# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 #
|
||||
##########################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributions.categorical import Categorical
|
||||
|
||||
|
||||
class Controller(nn.Module):
|
||||
# we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
|
||||
def __init__(
|
||||
self,
|
||||
num_edge,
|
||||
num_ops,
|
||||
lstm_size=32,
|
||||
lstm_num_layers=2,
|
||||
tanh_constant=2.5,
|
||||
temperature=5.0,
|
||||
):
|
||||
super(Controller, self).__init__()
|
||||
# assign the attributes
|
||||
self.num_edge = num_edge
|
||||
self.num_ops = num_ops
|
||||
self.lstm_size = lstm_size
|
||||
self.lstm_N = lstm_num_layers
|
||||
self.tanh_constant = tanh_constant
|
||||
self.temperature = temperature
|
||||
# create parameters
|
||||
self.register_parameter(
|
||||
"input_vars", nn.Parameter(torch.Tensor(1, 1, lstm_size))
|
||||
)
|
||||
self.w_lstm = nn.LSTM(
|
||||
input_size=self.lstm_size,
|
||||
hidden_size=self.lstm_size,
|
||||
num_layers=self.lstm_N,
|
||||
)
|
||||
self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
|
||||
self.w_pred = nn.Linear(self.lstm_size, self.num_ops)
|
||||
|
||||
nn.init.uniform_(self.input_vars, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_embd.weight, -0.1, 0.1)
|
||||
nn.init.uniform_(self.w_pred.weight, -0.1, 0.1)
|
||||
|
||||
def forward(self):
|
||||
|
||||
inputs, h0 = self.input_vars, None
|
||||
log_probs, entropys, sampled_arch = [], [], []
|
||||
for iedge in range(self.num_edge):
|
||||
outputs, h0 = self.w_lstm(inputs, h0)
|
||||
|
||||
logits = self.w_pred(outputs)
|
||||
logits = logits / self.temperature
|
||||
logits = self.tanh_constant * torch.tanh(logits)
|
||||
# distribution
|
||||
op_distribution = Categorical(logits=logits)
|
||||
op_index = op_distribution.sample()
|
||||
sampled_arch.append(op_index.item())
|
||||
|
||||
op_log_prob = op_distribution.log_prob(op_index)
|
||||
log_probs.append(op_log_prob.view(-1))
|
||||
op_entropy = op_distribution.entropy()
|
||||
entropys.append(op_entropy.view(-1))
|
||||
|
||||
# obtain the input embedding for the next step
|
||||
inputs = self.w_embd(op_index)
|
||||
return (
|
||||
torch.sum(torch.cat(log_probs)),
|
||||
torch.sum(torch.cat(entropys)),
|
||||
sampled_arch,
|
||||
)
|
||||
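# Sketch of how this controller is typically consumed during ENAS search (illustrative):
#   log_prob, entropy, sampled_arch = controller()
# `sampled_arch` is a list with one operation index per edge; the shared-weight network's
# update_arch() (see TinyNetworkENAS above) converts it into a Structure before forward().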
142
AutoDL-Projects/xautodl/models/cell_searchs/search_model_gdas.py
Normal file
@@ -0,0 +1,142 @@
|
||||
###########################################################################
|
||||
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
|
||||
###########################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
class TinyNetworkGDAS(nn.Module):
|
||||
|
||||
# def __init__(self, C, N, max_nodes, num_classes, search_space, affine=False, track_running_stats=True):
|
||||
def __init__(
|
||||
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
|
||||
):
|
||||
super(TinyNetworkGDAS, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self.max_nodes = max_nodes
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
|
||||
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
C_prev, num_edge, edge2index = C, None, None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
C_prev,
|
||||
C_curr,
|
||||
1,
|
||||
max_nodes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.tau = 10
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def set_tau(self, tau):
|
||||
self.tau = tau
|
||||
|
||||
def get_tau(self):
|
||||
return self.tau
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_parameters]
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
return "arch-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
|
||||
)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def genotype(self):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
with torch.no_grad():
|
||||
weights = self.arch_parameters[self.edge2index[node_str]]
|
||||
op_name = self.op_names[weights.argmax().item()]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def forward(self, inputs):
|
||||
while True:
|
||||
gumbels = -torch.empty_like(self.arch_parameters).exponential_().log()
|
||||
logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau
|
||||
probs = nn.functional.softmax(logits, dim=1)
|
||||
index = probs.max(-1, keepdim=True)[1]
|
||||
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
|
||||
hardwts = one_h - probs.detach() + probs
|
||||
if (
|
||||
(torch.isinf(gumbels).any())
|
||||
or (torch.isinf(probs).any())
|
||||
or (torch.isnan(probs).any())
|
||||
):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
feature = cell.forward_gdas(feature, hardwts, index)
|
||||
else:
|
||||
feature = cell(feature)
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
||||
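# During GDAS search the temperature tau is usually annealed from a large value towards a
# small one (roughly 10 -> 0.1 across the search epochs), e.g. once per epoch (illustrative):
#   net.set_tau(10 - (10 - 0.1) * epoch / (total_epochs - 1))
# so that the Gumbel-softmax samples become progressively closer to hard one-hot choices.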
@@ -0,0 +1,200 @@
|
||||
###########################################################################
|
||||
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
|
||||
###########################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
|
||||
from .search_cells import NASNetSearchCell as SearchCell
|
||||
from ..cell_operations import RAW_OP_CLASSES
|
||||
|
||||
|
||||
# The macro structure is based on NASNet
|
||||
class NASNetworkGDAS_FRC(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C,
|
||||
N,
|
||||
steps,
|
||||
multiplier,
|
||||
stem_multiplier,
|
||||
num_classes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
):
|
||||
super(NASNetworkGDAS_FRC, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C * stem_multiplier),
|
||||
)
|
||||
|
||||
# config for each layer
|
||||
layer_channels = (
|
||||
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
|
||||
)
|
||||
layer_reductions = (
|
||||
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
|
||||
)
|
||||
|
||||
num_edge, edge2index = None, None
|
||||
C_prev_prev, C_prev, C_curr, reduction_prev = (
|
||||
C * stem_multiplier,
|
||||
C * stem_multiplier,
|
||||
C,
|
||||
False,
|
||||
)
|
||||
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = RAW_OP_CLASSES["gdas_reduction"](
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C_curr,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
search_space,
|
||||
steps,
|
||||
multiplier,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C_curr,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
reduction
|
||||
or num_edge == cell.num_edges
|
||||
and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev_prev, C_prev, reduction_prev = (
|
||||
C_prev,
|
||||
cell.multiplier * C_curr,
|
||||
reduction,
|
||||
)
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.tau = 10
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def set_tau(self, tau):
|
||||
self.tau = tau
|
||||
|
||||
def get_tau(self):
|
||||
return self.tau
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_parameters]
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
A = "arch-normal-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_parameters, dim=-1).cpu()
|
||||
)
|
||||
return "{:}".format(A)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def genotype(self):
|
||||
def _parse(weights):
|
||||
gene = []
|
||||
for i in range(self._steps):
|
||||
edges = []
|
||||
for j in range(2 + i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
ws = weights[self.edge2index[node_str]]
|
||||
for k, op_name in enumerate(self.op_names):
|
||||
if op_name == "none":
|
||||
continue
|
||||
edges.append((op_name, j, ws[k]))
|
||||
edges = sorted(edges, key=lambda x: -x[-1])
|
||||
selected_edges = edges[:2]
|
||||
gene.append(tuple(selected_edges))
|
||||
return gene
|
||||
|
||||
with torch.no_grad():
|
||||
gene_normal = _parse(
|
||||
torch.softmax(self.arch_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
return {
|
||||
"normal": gene_normal,
|
||||
"normal_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
}
|
||||
|
||||
def forward(self, inputs):
|
||||
def get_gumbel_prob(xins):
|
||||
while True:
|
||||
gumbels = -torch.empty_like(xins).exponential_().log()
|
||||
logits = (xins.log_softmax(dim=1) + gumbels) / self.tau
|
||||
probs = nn.functional.softmax(logits, dim=1)
|
||||
index = probs.max(-1, keepdim=True)[1]
|
||||
one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
|
||||
hardwts = one_h - probs.detach() + probs
|
||||
if (
|
||||
(torch.isinf(gumbels).any())
|
||||
or (torch.isinf(probs).any())
|
||||
or (torch.isnan(probs).any())
|
||||
):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
return hardwts, index
|
||||
|
||||
hardwts, index = get_gumbel_prob(self.arch_parameters)
|
||||
|
||||
s0 = s1 = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if cell.reduction:
|
||||
s0, s1 = s1, cell(s0, s1)
|
||||
else:
|
||||
s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
|
||||
out = self.lastact(s1)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
||||
@@ -0,0 +1,197 @@
|
||||
###########################################################################
|
||||
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
|
||||
###########################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from .search_cells import NASNetSearchCell as SearchCell
|
||||
|
||||
|
||||
# The macro structure is based on NASNet
|
||||
class NASNetworkGDAS(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C,
|
||||
N,
|
||||
steps,
|
||||
multiplier,
|
||||
stem_multiplier,
|
||||
num_classes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
):
|
||||
super(NASNetworkGDAS, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C * stem_multiplier),
|
||||
)
|
||||
|
||||
# config for each layer
|
||||
layer_channels = (
|
||||
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
|
||||
)
|
||||
layer_reductions = (
|
||||
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
|
||||
)
|
||||
|
||||
num_edge, edge2index = None, None
|
||||
C_prev_prev, C_prev, C_curr, reduction_prev = (
|
||||
C * stem_multiplier,
|
||||
C * stem_multiplier,
|
||||
C,
|
||||
False,
|
||||
)
|
||||
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
cell = SearchCell(
|
||||
search_space,
|
||||
steps,
|
||||
multiplier,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C_curr,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_normal_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.arch_reduce_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.tau = 10
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def set_tau(self, tau):
|
||||
self.tau = tau
|
||||
|
||||
def get_tau(self):
|
||||
return self.tau
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_normal_parameters, self.arch_reduce_parameters]
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
A = "arch-normal-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
|
||||
)
|
||||
B = "arch-reduce-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
|
||||
)
|
||||
return "{:}\n{:}".format(A, B)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def genotype(self):
|
||||
def _parse(weights):
|
||||
gene = []
|
||||
for i in range(self._steps):
|
||||
edges = []
|
||||
for j in range(2 + i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
ws = weights[self.edge2index[node_str]]
|
||||
for k, op_name in enumerate(self.op_names):
|
||||
if op_name == "none":
|
||||
continue
|
||||
edges.append((op_name, j, ws[k]))
|
||||
edges = sorted(edges, key=lambda x: -x[-1])
|
||||
selected_edges = edges[:2]
|
||||
gene.append(tuple(selected_edges))
|
||||
return gene
|
||||
|
||||
with torch.no_grad():
|
||||
gene_normal = _parse(
|
||||
torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
gene_reduce = _parse(
|
||||
torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
return {
|
||||
"normal": gene_normal,
|
||||
"normal_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
"reduce": gene_reduce,
|
||||
"reduce_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
}
|
||||
|
||||
    def forward(self, inputs):
        def get_gumbel_prob(xins):
            while True:
                gumbels = -torch.empty_like(xins).exponential_().log()
                logits = (xins.log_softmax(dim=1) + gumbels) / self.tau
                probs = nn.functional.softmax(logits, dim=1)
                index = probs.max(-1, keepdim=True)[1]
                one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0)
                hardwts = one_h - probs.detach() + probs
                if (
                    (torch.isinf(gumbels).any())
                    or (torch.isinf(probs).any())
                    or (torch.isnan(probs).any())
                ):
                    continue
                else:
                    break
            return hardwts, index

        normal_hardwts, normal_index = get_gumbel_prob(self.arch_normal_parameters)
        reduce_hardwts, reduce_index = get_gumbel_prob(self.arch_reduce_parameters)

        s0 = s1 = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            if cell.reduction:
                hardwts, index = reduce_hardwts, reduce_index
            else:
                hardwts, index = normal_hardwts, normal_index
            s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
        out = self.lastact(s1)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        return out, logits
|
||||
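The forward pass above samples one operation per edge with a straight-through Gumbel-Softmax: hard one-hot weights on the forward path, soft gradients on the backward path, and a retry loop that guards against inf/nan draws. A minimal standalone sketch of the same trick follows; the tensor sizes and tau value are illustrative assumptions, not values taken from the file above.

import torch
import torch.nn.functional as F

def gumbel_softmax_hard(logits: torch.Tensor, tau: float = 10.0) -> torch.Tensor:
    # logits: (num_edges, num_ops) architecture parameters.
    while True:
        gumbels = -torch.empty_like(logits).exponential_().log()
        scores = (logits.log_softmax(dim=-1) + gumbels) / tau
        probs = F.softmax(scores, dim=-1)
        if torch.isinf(gumbels).any() or torch.isinf(probs).any() or torch.isnan(probs).any():
            continue  # re-sample on numerical trouble, as the model code does
        index = probs.max(-1, keepdim=True)[1]
        one_hot = torch.zeros_like(probs).scatter_(-1, index, 1.0)
        # hard one-hot forward, soft gradient backward (straight-through estimator)
        return one_hot - probs.detach() + probs

# usage: 14 edges, 8 candidate operations (illustrative sizes)
alphas = torch.nn.Parameter(1e-3 * torch.randn(14, 8))
hardwts = gumbel_softmax_hard(alphas, tau=10.0)
print(hardwts.sum(dim=-1))  # each row is one-hot in the forward pass, so it sums to 1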
@@ -0,0 +1,102 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##############################################################################
# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 #
##############################################################################
import torch, random
import torch.nn as nn
from copy import deepcopy

from ..cell_operations import ResNetBasicblock
from .search_cells import NAS201SearchCell as SearchCell
from .genotypes import Structure


class TinyNetworkRANDOM(nn.Module):
    def __init__(
        self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
    ):
        super(TinyNetworkRANDOM, self).__init__()
        self._C = C
        self._layerN = N
        self.max_nodes = max_nodes
        self.stem = nn.Sequential(
            nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
        )

        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        C_prev, num_edge, edge2index = C, None, None
        self.cells = nn.ModuleList()
        for index, (C_curr, reduction) in enumerate(
            zip(layer_channels, layer_reductions)
        ):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                cell = SearchCell(
                    C_prev,
                    C_curr,
                    1,
                    max_nodes,
                    search_space,
                    affine,
                    track_running_stats,
                )
                if num_edge is None:
                    num_edge, edge2index = cell.num_edges, cell.edge2index
                else:
                    assert (
                        num_edge == cell.num_edges and edge2index == cell.edge2index
                    ), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
            self.cells.append(cell)
            C_prev = cell.out_dim
        self.op_names = deepcopy(search_space)
        self._Layer = len(self.cells)
        self.edge2index = edge2index
        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        self.arch_cache = None

    def get_message(self):
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += "\n {:02d}/{:02d} :: {:}".format(
                i, len(self.cells), cell.extra_repr()
            )
        return string

    def extra_repr(self):
        return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
            name=self.__class__.__name__, **self.__dict__
        )

    def random_genotype(self, set_cache):
        genotypes = []
        for i in range(1, self.max_nodes):
            xlist = []
            for j in range(i):
                node_str = "{:}<-{:}".format(i, j)
                op_name = random.choice(self.op_names)
                xlist.append((op_name, j))
            genotypes.append(tuple(xlist))
        arch = Structure(genotypes)
        if set_cache:
            self.arch_cache = arch
        return arch

    def forward(self, inputs):
        feature = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            if isinstance(cell, SearchCell):
                feature = cell.forward_dynamic(feature, self.arch_cache)
            else:
                feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)
        return out, logits
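random_genotype builds a cell description by drawing one operation for every edge j -> i with j < i. A self-contained sketch of the same sampling, outside the class, is below; the operation list is an assumption used only for illustration.

import random

# Candidate operations, assumed here for illustration (NAS-Bench-201-style names).
OPS = ["none", "skip_connect", "nor_conv_1x1", "nor_conv_3x3", "avg_pool_3x3"]

def random_cell_genotype(max_nodes: int = 4):
    # Mirrors random_genotype above: node i draws one op for every incoming edge j -> i.
    genotypes = []
    for i in range(1, max_nodes):
        edges = tuple((random.choice(OPS), j) for j in range(i))
        genotypes.append(edges)
    return genotypes

random.seed(0)
print(random_cell_genotype())
# e.g. [(('nor_conv_3x3', 0),), (('skip_connect', 0), ('none', 1)), ...]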
178
AutoDL-Projects/xautodl/models/cell_searchs/search_model_setn.py
Normal file
@@ -0,0 +1,178 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
######################################################################################
|
||||
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
|
||||
######################################################################################
|
||||
import torch, random
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from .search_cells import NAS201SearchCell as SearchCell
|
||||
from .genotypes import Structure
|
||||
|
||||
|
||||
class TinyNetworkSETN(nn.Module):
|
||||
def __init__(
|
||||
self, C, N, max_nodes, num_classes, search_space, affine, track_running_stats
|
||||
):
|
||||
super(TinyNetworkSETN, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self.max_nodes = max_nodes
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C)
|
||||
)
|
||||
|
||||
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
C_prev, num_edge, edge2index = C, None, None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(C_prev, C_curr, 2)
|
||||
else:
|
||||
cell = SearchCell(
|
||||
C_prev,
|
||||
C_curr,
|
||||
1,
|
||||
max_nodes,
|
||||
search_space,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev = cell.out_dim
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.mode = "urs"
|
||||
self.dynamic_cell = None
|
||||
|
||||
def set_cal_mode(self, mode, dynamic_cell=None):
|
||||
assert mode in ["urs", "joint", "select", "dynamic"]
|
||||
self.mode = mode
|
||||
if mode == "dynamic":
|
||||
self.dynamic_cell = deepcopy(dynamic_cell)
|
||||
else:
|
||||
self.dynamic_cell = None
|
||||
|
||||
def get_cal_mode(self):
|
||||
return self.mode
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_parameters]
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def genotype(self):
|
||||
genotypes = []
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
with torch.no_grad():
|
||||
weights = self.arch_parameters[self.edge2index[node_str]]
|
||||
op_name = self.op_names[weights.argmax().item()]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def dync_genotype(self, use_random=False):
|
||||
genotypes = []
|
||||
with torch.no_grad():
|
||||
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
if use_random:
|
||||
op_name = random.choice(self.op_names)
|
||||
else:
|
||||
weights = alphas_cpu[self.edge2index[node_str]]
|
||||
op_index = torch.multinomial(weights, 1).item()
|
||||
op_name = self.op_names[op_index]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def get_log_prob(self, arch):
|
||||
with torch.no_grad():
|
||||
logits = nn.functional.log_softmax(self.arch_parameters, dim=-1)
|
||||
select_logits = []
|
||||
for i, node_info in enumerate(arch.nodes):
|
||||
for op, xin in node_info:
|
||||
node_str = "{:}<-{:}".format(i + 1, xin)
|
||||
op_index = self.op_names.index(op)
|
||||
select_logits.append(logits[self.edge2index[node_str], op_index])
|
||||
return sum(select_logits).item()
|
||||
|
||||
def return_topK(self, K):
|
||||
archs = Structure.gen_all(self.op_names, self.max_nodes, False)
|
||||
pairs = [(self.get_log_prob(arch), arch) for arch in archs]
|
||||
if K < 0 or K >= len(archs):
|
||||
K = len(archs)
|
||||
sorted_pairs = sorted(pairs, key=lambda x: -x[0])
|
||||
return_pairs = [sorted_pairs[_][1] for _ in range(K)]
|
||||
return return_pairs
|
||||
|
||||
def forward(self, inputs):
|
||||
alphas = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
with torch.no_grad():
|
||||
alphas_cpu = alphas.detach().cpu()
|
||||
|
||||
feature = self.stem(inputs)
|
||||
for i, cell in enumerate(self.cells):
|
||||
if isinstance(cell, SearchCell):
|
||||
if self.mode == "urs":
|
||||
feature = cell.forward_urs(feature)
|
||||
elif self.mode == "select":
|
||||
feature = cell.forward_select(feature, alphas_cpu)
|
||||
elif self.mode == "joint":
|
||||
feature = cell.forward_joint(feature, alphas)
|
||||
elif self.mode == "dynamic":
|
||||
feature = cell.forward_dynamic(feature, self.dynamic_cell)
|
||||
else:
|
||||
raise ValueError("invalid mode={:}".format(self.mode))
|
||||
else:
|
||||
feature = cell(feature)
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits
|
||||
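get_log_prob and return_topK above score a candidate architecture by summing, over its chosen edges, the log-softmax of the selected operation. A hedged standalone sketch of that scoring follows; the edge map, op names, and candidate architecture are toy assumptions, not values from the file.

import torch
import torch.nn as nn

# Toy setting: 3 edges, 4 candidate ops (illustrative sizes).
op_names = ["none", "skip_connect", "conv_1x1", "conv_3x3"]
edge2index = {"1<-0": 0, "2<-0": 1, "2<-1": 2}
arch_parameters = nn.Parameter(1e-3 * torch.randn(len(edge2index), len(op_names)))

def arch_log_prob(arch):
    # arch: list over nodes; each node is a list of (op_name, input_node) pairs,
    # mirroring the Structure.nodes layout assumed by get_log_prob above.
    logits = nn.functional.log_softmax(arch_parameters, dim=-1)
    total = 0.0
    for i, node_info in enumerate(arch):
        for op, xin in node_info:
            node_str = "{:}<-{:}".format(i + 1, xin)
            total += logits[edge2index[node_str], op_names.index(op)].item()
    return total

candidate = [[("conv_3x3", 0)], [("skip_connect", 0), ("conv_1x1", 1)]]
print(arch_log_prob(candidate))  # less negative means more likely under the current alphas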
@@ -0,0 +1,205 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
######################################################################################
|
||||
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 #
|
||||
######################################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
from typing import List, Text, Dict
|
||||
from .search_cells import NASNetSearchCell as SearchCell
|
||||
|
||||
|
||||
# The macro structure is based on NASNet
|
||||
class NASNetworkSETN(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
C: int,
|
||||
N: int,
|
||||
steps: int,
|
||||
multiplier: int,
|
||||
stem_multiplier: int,
|
||||
num_classes: int,
|
||||
search_space: List[Text],
|
||||
affine: bool,
|
||||
track_running_stats: bool,
|
||||
):
|
||||
super(NASNetworkSETN, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(3, C * stem_multiplier, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C * stem_multiplier),
|
||||
)
|
||||
|
||||
# config for each layer
|
||||
layer_channels = (
|
||||
[C] * N + [C * 2] + [C * 2] * (N - 1) + [C * 4] + [C * 4] * (N - 1)
|
||||
)
|
||||
layer_reductions = (
|
||||
[False] * N + [True] + [False] * (N - 1) + [True] + [False] * (N - 1)
|
||||
)
|
||||
|
||||
num_edge, edge2index = None, None
|
||||
C_prev_prev, C_prev, C_curr, reduction_prev = (
|
||||
C * stem_multiplier,
|
||||
C * stem_multiplier,
|
||||
C,
|
||||
False,
|
||||
)
|
||||
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(
|
||||
zip(layer_channels, layer_reductions)
|
||||
):
|
||||
cell = SearchCell(
|
||||
search_space,
|
||||
steps,
|
||||
multiplier,
|
||||
C_prev_prev,
|
||||
C_prev,
|
||||
C_curr,
|
||||
reduction,
|
||||
reduction_prev,
|
||||
affine,
|
||||
track_running_stats,
|
||||
)
|
||||
if num_edge is None:
|
||||
num_edge, edge2index = cell.num_edges, cell.edge2index
|
||||
else:
|
||||
assert (
|
||||
num_edge == cell.num_edges and edge2index == cell.edge2index
|
||||
), "invalid {:} vs. {:}.".format(num_edge, cell.num_edges)
|
||||
self.cells.append(cell)
|
||||
C_prev_prev, C_prev, reduction_prev = C_prev, multiplier * C_curr, reduction
|
||||
self.op_names = deepcopy(search_space)
|
||||
self._Layer = len(self.cells)
|
||||
self.edge2index = edge2index
|
||||
self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.arch_normal_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.arch_reduce_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(num_edge, len(search_space))
|
||||
)
|
||||
self.mode = "urs"
|
||||
self.dynamic_cell = None
|
||||
|
||||
def set_cal_mode(self, mode, dynamic_cell=None):
|
||||
assert mode in ["urs", "joint", "select", "dynamic"]
|
||||
self.mode = mode
|
||||
if mode == "dynamic":
|
||||
self.dynamic_cell = deepcopy(dynamic_cell)
|
||||
else:
|
||||
self.dynamic_cell = None
|
||||
|
||||
def get_weights(self):
|
||||
xlist = list(self.stem.parameters()) + list(self.cells.parameters())
|
||||
xlist += list(self.lastact.parameters()) + list(
|
||||
self.global_pooling.parameters()
|
||||
)
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
def get_alphas(self):
|
||||
return [self.arch_normal_parameters, self.arch_reduce_parameters]
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
A = "arch-normal-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_normal_parameters, dim=-1).cpu()
|
||||
)
|
||||
B = "arch-reduce-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self.arch_reduce_parameters, dim=-1).cpu()
|
||||
)
|
||||
return "{:}\n{:}".format(A, B)
|
||||
|
||||
def get_message(self):
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self.cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self.cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(C={_C}, N={_layerN}, steps={_steps}, multiplier={_multiplier}, L={_Layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def dync_genotype(self, use_random=False):
|
||||
genotypes = []
|
||||
with torch.no_grad():
|
||||
alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1)
|
||||
for i in range(1, self.max_nodes):
|
||||
xlist = []
|
||||
for j in range(i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
if use_random:
|
||||
op_name = random.choice(self.op_names)
|
||||
else:
|
||||
weights = alphas_cpu[self.edge2index[node_str]]
|
||||
op_index = torch.multinomial(weights, 1).item()
|
||||
op_name = self.op_names[op_index]
|
||||
xlist.append((op_name, j))
|
||||
genotypes.append(tuple(xlist))
|
||||
return Structure(genotypes)
|
||||
|
||||
def genotype(self):
|
||||
def _parse(weights):
|
||||
gene = []
|
||||
for i in range(self._steps):
|
||||
edges = []
|
||||
for j in range(2 + i):
|
||||
node_str = "{:}<-{:}".format(i, j)
|
||||
ws = weights[self.edge2index[node_str]]
|
||||
for k, op_name in enumerate(self.op_names):
|
||||
if op_name == "none":
|
||||
continue
|
||||
edges.append((op_name, j, ws[k]))
|
||||
edges = sorted(edges, key=lambda x: -x[-1])
|
||||
selected_edges = edges[:2]
|
||||
gene.append(tuple(selected_edges))
|
||||
return gene
|
||||
|
||||
with torch.no_grad():
|
||||
gene_normal = _parse(
|
||||
torch.softmax(self.arch_normal_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
gene_reduce = _parse(
|
||||
torch.softmax(self.arch_reduce_parameters, dim=-1).cpu().numpy()
|
||||
)
|
||||
return {
|
||||
"normal": gene_normal,
|
||||
"normal_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
"reduce": gene_reduce,
|
||||
"reduce_concat": list(
|
||||
range(2 + self._steps - self._multiplier, self._steps + 2)
|
||||
),
|
||||
}
|
||||
|
||||
    def forward(self, inputs):
        normal_hardwts = nn.functional.softmax(self.arch_normal_parameters, dim=-1)
        reduce_hardwts = nn.functional.softmax(self.arch_reduce_parameters, dim=-1)

        s0 = s1 = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            # [TODO] the SETN-specific cell forward is not implemented yet; the
            # statements below the raise are unreachable and appear to be scaffolding
            # from the GDAS-style forward (normal_index / reduce_index are never
            # defined in this class).
            raise NotImplementedError
            if cell.reduction:
                hardwts, index = reduce_hardwts, reduce_index
            else:
                hardwts, index = normal_hardwts, normal_index
            s0, s1 = s1, cell.forward_gdas(s0, s1, hardwts, index)
        out = self.lastact(s1)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        return out, logits
|
||||
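dync_genotype above draws one operation per edge from the softmax over that edge's architecture weights via torch.multinomial. A small hedged sketch of that per-edge sampling follows; the tensor sizes are illustrative assumptions.

import torch

# Edge-wise architecture weights: 4 edges x 5 candidate ops (illustrative sizes).
torch.manual_seed(0)
alphas = torch.randn(4, 5)
probs = torch.softmax(alphas, dim=-1)

# dync_genotype-style sampling: one op index per edge, drawn from its softmax distribution.
sampled = torch.multinomial(probs, 1).squeeze(-1)
print(sampled)           # stochastic choice per edge
print(probs.argmax(-1))  # the greedy choice used by genotype(), for comparison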
74
AutoDL-Projects/xautodl/models/clone_weights.py
Normal file
@@ -0,0 +1,74 @@
import torch
import torch.nn as nn


def copy_conv(module, init):
    assert isinstance(module, nn.Conv2d), "invalid module : {:}".format(module)
    assert isinstance(init, nn.Conv2d), "invalid module : {:}".format(init)
    new_i, new_o = module.in_channels, module.out_channels
    module.weight.copy_(init.weight.detach()[:new_o, :new_i])
    if module.bias is not None:
        module.bias.copy_(init.bias.detach()[:new_o])


def copy_bn(module, init):
    assert isinstance(module, nn.BatchNorm2d), "invalid module : {:}".format(module)
    assert isinstance(init, nn.BatchNorm2d), "invalid module : {:}".format(init)
    num_features = module.num_features
    if module.weight is not None:
        module.weight.copy_(init.weight.detach()[:num_features])
    if module.bias is not None:
        module.bias.copy_(init.bias.detach()[:num_features])
    if module.running_mean is not None:
        module.running_mean.copy_(init.running_mean.detach()[:num_features])
    if module.running_var is not None:
        module.running_var.copy_(init.running_var.detach()[:num_features])


def copy_fc(module, init):
    assert isinstance(module, nn.Linear), "invalid module : {:}".format(module)
    assert isinstance(init, nn.Linear), "invalid module : {:}".format(init)
    new_i, new_o = module.in_features, module.out_features
    module.weight.copy_(init.weight.detach()[:new_o, :new_i])
    if module.bias is not None:
        module.bias.copy_(init.bias.detach()[:new_o])


def copy_base(module, init):
    assert type(module).__name__ in [
        "ConvBNReLU",
        "Downsample",
    ], "invalid module : {:}".format(module)
    assert type(init).__name__ in [
        "ConvBNReLU",
        "Downsample",
    ], "invalid module : {:}".format(init)
    if module.conv is not None:
        copy_conv(module.conv, init.conv)
    if module.bn is not None:
        copy_bn(module.bn, init.bn)


def copy_basic(module, init):
    copy_base(module.conv_a, init.conv_a)
    copy_base(module.conv_b, init.conv_b)
    if module.downsample is not None:
        if init.downsample is not None:
            copy_base(module.downsample, init.downsample)
    # else:
    #     import pdb; pdb.set_trace()


def init_from_model(network, init_model):
    with torch.no_grad():
        copy_fc(network.classifier, init_model.classifier)
        for base, target in zip(init_model.layers, network.layers):
            assert (
                type(base).__name__ == type(target).__name__
            ), "invalid type : {:} vs {:}".format(base, target)
            if type(base).__name__ == "ConvBNReLU":
                copy_base(target, base)
            elif type(base).__name__ == "ResNetBasicblock":
                copy_basic(target, base)
            else:
                raise ValueError("unknown type name : {:}".format(type(base).__name__))
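The copy_* helpers initialize a narrower layer from a wider one by slicing the leading input and output channels of the source weights. A minimal hedged sketch of that slicing on plain nn.Conv2d modules (the layer sizes are arbitrary assumptions):

import torch
import torch.nn as nn

# A wide "pretrained" conv and a narrower target conv (illustrative sizes).
wide = nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=True)
slim = nn.Conv2d(8, 12, kernel_size=3, padding=1, bias=True)

with torch.no_grad():  # the same guard init_from_model uses
    new_i, new_o = slim.in_channels, slim.out_channels
    slim.weight.copy_(wide.weight.detach()[:new_o, :new_i])
    if slim.bias is not None:
        slim.bias.copy_(wide.bias.detach()[:new_o])

print(torch.equal(slim.weight, wide.weight[:12, :8]))  # True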
16
AutoDL-Projects/xautodl/models/initialization.py
Normal file
@@ -0,0 +1,16 @@
import torch
import torch.nn as nn


def initialize_resnet(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        nn.init.constant_(m.bias, 0)
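initialize_resnet is meant to be passed to nn.Module.apply, which visits every submodule once. A hedged usage sketch follows; it uses a local copy of the same dispatch so the snippet stands alone, and the layer sizes are arbitrary assumptions.

import torch.nn as nn

def init_weights(m):
    # Same pattern as initialize_resnet above: dispatch on the module type.
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        nn.init.constant_(m.bias, 0)

# .apply() walks every submodule, so one call initializes the whole network.
net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.Flatten(),
    nn.Linear(16 * 32 * 32, 10),
)
net.apply(init_weights)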
287
AutoDL-Projects/xautodl/models/shape_infers/InferCifarResNet.py
Normal file
@@ -0,0 +1,287 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
def __init__(
|
||||
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
if has_bn:
|
||||
self.bn = nn.BatchNorm2d(nOut)
|
||||
else:
|
||||
self.bn = None
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.bn:
|
||||
out = self.bn(conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs)
|
||||
|
||||
self.conv_a = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[1],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[2]
|
||||
elif iCs[0] != iCs[2]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[2])
|
||||
self.out_dim = iCs[2]
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
iCs[1],
|
||||
iCs[2],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
elif iCs[0] != iCs[3]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[3])
|
||||
self.out_dim = iCs[3]
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class InferCifarResNet(nn.Module):
|
||||
def __init__(
|
||||
self, block_name, depth, xblocks, xchannels, num_classes, zero_init_residual
|
||||
):
|
||||
super(InferCifarResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == "ResNetBasicblock":
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == "ResNetBottleneck":
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, "depth should be one of 164"
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks)
|
||||
|
||||
self.message = (
|
||||
"InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
|
||||
depth, layer_blocks
|
||||
)
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.xchannels = xchannels
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
xchannels[0],
|
||||
xchannels[1],
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
last_channel_idx = 1
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
num_conv = block.num_conv
|
||||
iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iCs, stride)
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iCs,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
if iL + 1 == xblocks[stage]: # reach the maximum depth
|
||||
out_channel = module.out_dim
|
||||
for iiL in range(iL + 1, layer_blocks):
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
break
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
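The InferCifarResNet constructor consumes a flattened per-layer channel list: each block reads num_conv + 1 consecutive entries (input, intermediate, output channels) and the cursor then advances by num_conv. A small sketch of that walk follows; the channel numbers are made up, a real xchannels comes from a searched configuration.

# How the constructor walks a flattened per-layer channel list.
xchannels = [3, 16, 12, 16, 8, 16, 14, 16]  # stem + three basic blocks (num_conv = 2 each)
num_conv = 2  # ResNetBasicblock

last_channel_idx = 1  # indices 0 and 1 were consumed by the stem ConvBNReLU
for block_id in range(3):
    iCs = xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
    print("block", block_id, "iCs =", iCs)  # (input C, mid C, output C) for the block
    last_channel_idx += num_conv
# block 0 iCs = [16, 12, 16]
# block 1 iCs = [16, 8, 16]
# block 2 iCs = [16, 14, 16]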
@@ -0,0 +1,263 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
def __init__(
|
||||
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
if has_bn:
|
||||
self.bn = nn.BatchNorm2d(nOut)
|
||||
else:
|
||||
self.bn = None
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.bn:
|
||||
out = self.bn(conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
|
||||
self.conv_a = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
planes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
planes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes * self.expansion:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes * self.expansion
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class InferDepthCifarResNet(nn.Module):
|
||||
def __init__(self, block_name, depth, xblocks, num_classes, zero_init_residual):
|
||||
super(InferDepthCifarResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == "ResNetBasicblock":
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == "ResNetBottleneck":
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, "depth should be one of 164"
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
assert len(xblocks) == 3, "invalid xblocks : {:}".format(xblocks)
|
||||
|
||||
self.message = (
|
||||
"InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
|
||||
depth, layer_blocks
|
||||
)
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
]
|
||||
)
|
||||
self.channels = [16]
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
iC = self.channels[-1]
|
||||
planes = 16 * (2 ** stage)
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iC, planes, stride)
|
||||
self.channels.append(module.out_dim)
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
planes,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
if iL + 1 == xblocks[stage]: # reach the maximum depth
|
||||
break
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(self.channels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
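The depth checks in these CIFAR ResNets follow the usual arithmetic: a basic-block network of depth D has (D - 2) / 6 blocks per stage, a bottleneck network has (D - 2) / 9, and xblocks then truncates each stage. A tiny sketch of that arithmetic:

def blocks_per_stage(depth: int, block_name: str = "ResNetBasicblock") -> int:
    # Mirrors the asserts used in the constructors above.
    if block_name == "ResNetBasicblock":
        assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
        return (depth - 2) // 6
    elif block_name == "ResNetBottleneck":
        assert (depth - 2) % 9 == 0, "depth should be one of 164"
        return (depth - 2) // 9
    raise ValueError("invalid block : {:}".format(block_name))

print(blocks_per_stage(20))    # 3
print(blocks_per_stage(110))   # 18
print(blocks_per_stage(164, "ResNetBottleneck"))  # 18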
@@ -0,0 +1,277 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
def __init__(
|
||||
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
if has_bn:
|
||||
self.bn = nn.BatchNorm2d(nOut)
|
||||
else:
|
||||
self.bn = None
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.bn:
|
||||
out = self.bn(conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs)
|
||||
|
||||
self.conv_a = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[1],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[2]
|
||||
elif iCs[0] != iCs[2]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[2])
|
||||
self.out_dim = iCs[2]
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
iCs[1],
|
||||
iCs[2],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
elif iCs[0] != iCs[3]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[3])
|
||||
self.out_dim = iCs[3]
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class InferWidthCifarResNet(nn.Module):
|
||||
def __init__(self, block_name, depth, xchannels, num_classes, zero_init_residual):
|
||||
super(InferWidthCifarResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == "ResNetBasicblock":
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == "ResNetBottleneck":
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, "depth should be one of 164"
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
|
||||
self.message = (
|
||||
"InferWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
|
||||
depth, layer_blocks
|
||||
)
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.xchannels = xchannels
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
xchannels[0],
|
||||
xchannels[1],
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
last_channel_idx = 1
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
num_conv = block.num_conv
|
||||
iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iCs, stride)
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iCs,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
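These models share the zero_init_residual option: zeroing the gamma of the last BatchNorm in each residual branch makes the branch start near zero, so every block initially behaves close to the identity mapping. A minimal sketch of what that initialization does to a BatchNorm layer:

import torch.nn as nn

bn = nn.BatchNorm2d(8)
nn.init.constant_(bn.weight, 0)  # gamma = 0, so the branch output starts at (roughly) zero
print(bn.weight.detach().unique())  # tensor([0.])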
@@ -0,0 +1,324 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..initialization import initialize_resnet
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
|
||||
num_conv = 1
|
||||
|
||||
def __init__(
|
||||
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
if has_bn:
|
||||
self.bn = nn.BatchNorm2d(nOut)
|
||||
else:
|
||||
self.bn = None
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.bn:
|
||||
out = self.bn(conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
num_conv = 2
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 3, "invalid lengths of iCs : {:}".format(iCs)
|
||||
|
||||
self.conv_a = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[1],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
iCs[1], iCs[2], 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[2]
|
||||
elif iCs[0] != iCs[2]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[2],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[2])
|
||||
self.out_dim = iCs[2]
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + basicblock
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, iCs, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
assert isinstance(iCs, tuple) or isinstance(
|
||||
iCs, list
|
||||
), "invalid type of iCs : {:}".format(iCs)
|
||||
assert len(iCs) == 4, "invalid lengths of iCs : {:}".format(iCs)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
iCs[0], iCs[1], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
iCs[1],
|
||||
iCs[2],
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
iCs[2], iCs[3], 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
residual_in = iCs[0]
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
elif iCs[0] != iCs[3]:
|
||||
self.downsample = ConvBNReLU(
|
||||
iCs[0],
|
||||
iCs[3],
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
residual_in = iCs[3]
|
||||
else:
|
||||
self.downsample = None
|
||||
# self.out_dim = max(residual_in, iCs[3])
|
||||
self.out_dim = iCs[3]
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = residual + bottleneck
|
||||
return F.relu(out, inplace=True)
|
||||
|
||||
|
||||
class InferImagenetResNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
block_name,
|
||||
layers,
|
||||
xblocks,
|
||||
xchannels,
|
||||
deep_stem,
|
||||
num_classes,
|
||||
zero_init_residual,
|
||||
):
|
||||
super(InferImagenetResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == "BasicBlock":
|
||||
block = ResNetBasicblock
|
||||
elif block_name == "Bottleneck":
|
||||
block = ResNetBottleneck
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
assert len(xblocks) == len(
|
||||
layers
|
||||
), "invalid layers : {:} vs xblocks : {:}".format(layers, xblocks)
|
||||
|
||||
self.message = "InferImagenetResNet : Depth : {:} -> {:}, Layers for each block : {:}".format(
|
||||
sum(layers) * block.num_conv, sum(xblocks) * block.num_conv, xblocks
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.xchannels = xchannels
|
||||
if not deep_stem:
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
xchannels[0],
|
||||
xchannels[1],
|
||||
7,
|
||||
2,
|
||||
3,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
last_channel_idx = 1
|
||||
else:
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
xchannels[0],
|
||||
xchannels[1],
|
||||
3,
|
||||
2,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
),
|
||||
ConvBNReLU(
|
||||
xchannels[1],
|
||||
xchannels[2],
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
last_channel_idx = 2
|
||||
self.layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
|
||||
for stage, layer_blocks in enumerate(layers):
|
||||
for iL in range(layer_blocks):
|
||||
num_conv = block.num_conv
|
||||
iCs = self.xchannels[last_channel_idx : last_channel_idx + num_conv + 1]
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iCs, stride)
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iCs={:}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iCs,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
if iL + 1 == xblocks[stage]: # reach the maximum depth
|
||||
out_channel = module.out_dim
|
||||
for iiL in range(iL + 1, layer_blocks):
|
||||
last_channel_idx += num_conv
|
||||
self.xchannels[last_channel_idx] = module.out_dim
|
||||
break
|
||||
assert last_channel_idx + 1 == len(self.xchannels), "{:} vs {:}".format(
|
||||
last_channel_idx, len(self.xchannels)
|
||||
)
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.classifier = nn.Linear(self.xchannels[-1], num_classes)
|
||||
|
||||
self.apply(initialize_resnet)
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResNetBasicblock):
|
||||
nn.init.constant_(m.conv_b.bn.weight, 0)
|
||||
elif isinstance(m, ResNetBottleneck):
|
||||
nn.init.constant_(m.conv_1x4.bn.weight, 0)
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
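The ImageNet variant offers two stems: deep_stem=False uses one 7x7 stride-2 convolution, deep_stem=True stacks two 3x3 convolutions (stride 2 then 1); both halve the resolution before the shared max-pool. A hedged shape check follows; the channel counts are assumptions, since the real ones come from xchannels.

import torch
import torch.nn as nn

x = torch.randn(1, 3, 224, 224)
plain_stem = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
deep_stem = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
    nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
)
pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
print(pool(plain_stem(x)).shape)  # torch.Size([1, 64, 56, 56])
print(pool(deep_stem(x)).shape)   # torch.Size([1, 64, 56, 56])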
176
AutoDL-Projects/xautodl/models/shape_infers/InferMobileNetV2.py
Normal file
@@ -0,0 +1,176 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
# MobileNetV2: Inverted Residuals and Linear Bottlenecks, CVPR 2018
#####################################################
from torch import nn

from ..initialization import initialize_resnet
from ..SharedUtils import parse_channel_info


class ConvBNReLU(nn.Module):
    def __init__(
        self,
        in_planes,
        out_planes,
        kernel_size,
        stride,
        groups,
        has_bn=True,
        has_relu=True,
    ):
        super(ConvBNReLU, self).__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size,
            stride,
            padding,
            groups=groups,
            bias=False,
        )
        if has_bn:
            self.bn = nn.BatchNorm2d(out_planes)
        else:
            self.bn = None
        if has_relu:
            self.relu = nn.ReLU6(inplace=True)
        else:
            self.relu = None

    def forward(self, x):
        out = self.conv(x)
        if self.bn:
            out = self.bn(out)
        if self.relu:
            out = self.relu(out)
        return out


class InvertedResidual(nn.Module):
    def __init__(self, channels, stride, expand_ratio, additive):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2], "invalid stride : {:}".format(stride)
        assert len(channels) in [2, 3], "invalid channels : {:}".format(channels)

        if len(channels) == 2:
            layers = []
        else:
            layers = [ConvBNReLU(channels[0], channels[1], 1, 1, 1)]
        layers.extend(
            [
                # dw
                ConvBNReLU(channels[-2], channels[-2], 3, stride, channels[-2]),
                # pw-linear
                ConvBNReLU(channels[-2], channels[-1], 1, 1, 1, True, False),
            ]
        )
        self.conv = nn.Sequential(*layers)
        self.additive = additive
        if self.additive and channels[0] != channels[-1]:
            self.shortcut = ConvBNReLU(channels[0], channels[-1], 1, 1, 1, True, False)
        else:
            self.shortcut = None
        self.out_dim = channels[-1]

    def forward(self, x):
        out = self.conv(x)
        # if self.additive: return additive_func(out, x)
        if self.shortcut:
            return out + self.shortcut(x)
        else:
            return out


class InferMobileNetV2(nn.Module):
    def __init__(self, num_classes, xchannels, xblocks, dropout):
        super(InferMobileNetV2, self).__init__()
        block = InvertedResidual
        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]
        assert len(inverted_residual_setting) == len(
            xblocks
        ), "invalid number of layers : {:} vs {:}".format(
            len(inverted_residual_setting), len(xblocks)
        )
        for block_num, ir_setting in zip(xblocks, inverted_residual_setting):
            assert block_num <= ir_setting[2], "{:} vs {:}".format(
                block_num, ir_setting
            )
        xchannels = parse_channel_info(xchannels)
        # for i, chs in enumerate(xchannels):
        #     if i > 0: assert chs[0] == xchannels[i-1][-1], 'Layer[{:}] is invalid {:} vs {:}'.format(i, xchannels[i-1], chs)
        self.xchannels = xchannels
        self.message = "InferMobileNetV2 : xblocks={:}".format(xblocks)
        # building first layer
        features = [ConvBNReLU(xchannels[0][0], xchannels[0][1], 3, 2, 1)]
        last_channel_idx = 1

        # building inverted residual blocks
        for stage, (t, c, n, s) in enumerate(inverted_residual_setting):
            for i in range(n):
                stride = s if i == 0 else 1
                additv = True if i > 0 else False
                module = block(self.xchannels[last_channel_idx], stride, t, additv)
                features.append(module)
                self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, Cs={:}, stride={:}, expand={:}, original-C={:}".format(
                    stage,
                    i,
                    n,
                    len(features),
                    self.xchannels[last_channel_idx],
                    stride,
                    t,
                    c,
                )
                last_channel_idx += 1
                if i + 1 == xblocks[stage]:
                    out_channel = module.out_dim
                    for iiL in range(i + 1, n):
                        last_channel_idx += 1
                    self.xchannels[last_channel_idx][0] = module.out_dim
                    break
        # building last several layers
        features.append(
            ConvBNReLU(
                self.xchannels[last_channel_idx][0],
                self.xchannels[last_channel_idx][1],
                1,
                1,
                1,
            )
        )
        assert last_channel_idx + 2 == len(self.xchannels), "{:} vs {:}".format(
            last_channel_idx, len(self.xchannels)
        )
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.xchannels[last_channel_idx][1], num_classes),
        )

        # weight initialization
        self.apply(initialize_resnet)

    def get_message(self):
        return self.message

    def forward(self, inputs):
        features = self.features(inputs)
        vectors = features.mean([2, 3])
        predicts = self.classifier(vectors)
        return features, predicts
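Editor's note: a minimal shape sketch of the blocks defined above; the channel triple and input size are assumed for illustration, not taken from this commit. With stride=2 the depthwise 3x3 halves the spatial size, the linear pointwise projection sets the output width, and additive=False means no shortcut branch is built.

import torch

# hypothetical [in, expanded, out] channel triple; expand_ratio is accepted but
# unused by InvertedResidual -- the expansion is implied by the channel triple
block = InvertedResidual([32, 192, 64], stride=2, expand_ratio=6, additive=False)
x = torch.randn(2, 32, 56, 56)
y = block(x)
print(y.shape)  # torch.Size([2, 64, 28, 28])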
@@ -0,0 +1,65 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from typing import List, Text, Any
import torch.nn as nn

from ..cell_operations import ResNetBasicblock
from ..cell_infers.cells import InferCell


class DynamicShapeTinyNet(nn.Module):
    def __init__(self, channels: List[int], genotype: Any, num_classes: int):
        super(DynamicShapeTinyNet, self).__init__()
        self._channels = channels
        if len(channels) % 3 != 2:
            raise ValueError("invalid number of layers : {:}".format(len(channels)))
        self._num_stage = N = len(channels) // 3

        self.stem = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels[0]),
        )

        # layer_channels   = [C] * N + [C*2] + [C*2] * N + [C*4] + [C*4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        c_prev = channels[0]
        self.cells = nn.ModuleList()
        for index, (c_curr, reduction) in enumerate(zip(channels, layer_reductions)):
            if reduction:
                cell = ResNetBasicblock(c_prev, c_curr, 2, True)
            else:
                cell = InferCell(genotype, c_prev, c_curr, 1)
            self.cells.append(cell)
            c_prev = cell.out_dim
        self._num_layer = len(self.cells)

        self.lastact = nn.Sequential(nn.BatchNorm2d(c_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(c_prev, num_classes)

    def get_message(self) -> Text:
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += "\n {:02d}/{:02d} :: {:}".format(
                i, len(self.cells), cell.extra_repr()
            )
        return string

    def extra_repr(self):
        return "{name}(C={_channels}, N={_num_stage}, L={_num_layer})".format(
            name=self.__class__.__name__, **self.__dict__
        )

    def forward(self, inputs):
        feature = self.stem(inputs)
        for i, cell in enumerate(self.cells):
            feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        return out, logits
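Editor's note: the constructor above requires len(channels) % 3 == 2, i.e. 3N+2 cells with the two ResNet reduction cells at positions N and 2N+1. A small sketch with assumed values follows; building the full network additionally needs a cell genotype, which is omitted here.

C, N = 16, 5  # assumed base width and per-stage depth
channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
assert len(channels) % 3 == 2 and len(channels) == 3 * N + 2  # 17 cells in total
assert sum(layer_reductions) == 2  # exactly two reduction cells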
9
AutoDL-Projects/xautodl/models/shape_infers/__init__.py
Normal file
9
AutoDL-Projects/xautodl/models/shape_infers/__init__.py
Normal file
@@ -0,0 +1,9 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
from .InferCifarResNet_width import InferWidthCifarResNet
from .InferImagenetResNet import InferImagenetResNet
from .InferCifarResNet_depth import InferDepthCifarResNet
from .InferCifarResNet import InferCifarResNet
from .InferMobileNetV2 import InferMobileNetV2
from .InferTinyCellNet import DynamicShapeTinyNet
@@ -0,0 +1,5 @@
def parse_channel_info(xstring):
    blocks = xstring.split(" ")
    blocks = [x.split("-") for x in blocks]
    blocks = [[int(_) for _ in x] for x in blocks]
    return blocks
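Editor's note: a usage sketch of parse_channel_info with an assumed encoding string; spaces separate the per-layer blocks and dashes separate the channel counts inside a block.

print(parse_channel_info("3-16 16-24 24-24-96"))
# -> [[3, 16], [16, 24], [24, 24, 96]]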
@@ -0,0 +1,760 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import math, torch
|
||||
from collections import OrderedDict
|
||||
from bisect import bisect_right
|
||||
import torch.nn as nn
|
||||
from ..initialization import initialize_resnet
|
||||
from ..SharedUtils import additive_func
|
||||
from .SoftSelect import select2withP, ChannelWiseInter
|
||||
from .SoftSelect import linear_forward
|
||||
from .SoftSelect import get_width_choices
|
||||
|
||||
|
||||
def get_depth_choices(nDepth, return_num):
|
||||
if nDepth == 2:
|
||||
choices = (1, 2)
|
||||
elif nDepth == 3:
|
||||
choices = (1, 2, 3)
|
||||
elif nDepth > 3:
|
||||
choices = list(range(1, nDepth + 1, 2))
|
||||
if choices[-1] < nDepth:
|
||||
choices.append(nDepth)
|
||||
else:
|
||||
raise ValueError("invalid nDepth : {:}".format(nDepth))
|
||||
if return_num:
|
||||
return len(choices)
|
||||
else:
|
||||
return choices
|
||||
|
||||
|
||||
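Editor's note: the depth choices produced by get_depth_choices for a few representative depths; a sketch whose values follow directly from the branches above.

for n in (2, 3, 5, 9):
    print(n, get_depth_choices(n, False))
# 2 (1, 2)
# 3 (1, 2, 3)
# 5 [1, 3, 5]
# 9 [1, 3, 5, 7, 9]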
def conv_forward(inputs, conv, choices):
|
||||
iC = conv.in_channels
|
||||
fill_size = list(inputs.size())
|
||||
fill_size[1] = iC - fill_size[1]
|
||||
filled = torch.zeros(fill_size, device=inputs.device)
|
||||
xinputs = torch.cat((inputs, filled), dim=1)
|
||||
outputs = conv(xinputs)
|
||||
selecteds = [outputs[:, :oC] for oC in choices]
|
||||
return selecteds
|
||||
|
||||
|
||||
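Editor's note: conv_forward evaluates one full-width convolution and then slices its output channels, so several candidate widths share a single forward pass; an input with fewer channels than conv.in_channels is first zero-padded. A sketch with assumed sizes:

import torch, torch.nn as nn

conv = nn.Conv2d(16, 32, 3, padding=1, bias=False)
x = torch.randn(4, 12, 8, 8)            # only 12 of the 16 input channels are "active"
outs = conv_forward(x, conv, [8, 32])   # zero-pad to 16 channels, run once, slice twice
print([tuple(o.shape) for o in outs])   # [(4, 8, 8, 8), (4, 32, 8, 8)]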
class ConvBNReLU(nn.Module):
|
||||
num_conv = 1
|
||||
|
||||
def __init__(
|
||||
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
self.InShape = None
|
||||
self.OutShape = None
|
||||
self.choices = get_width_choices(nOut)
|
||||
self.register_buffer("choices_tensor", torch.Tensor(self.choices))
|
||||
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
# if has_bn : self.bn = nn.BatchNorm2d(nOut)
|
||||
# else : self.bn = None
|
||||
self.has_bn = has_bn
|
||||
self.BNs = nn.ModuleList()
|
||||
for i, _out in enumerate(self.choices):
|
||||
self.BNs.append(nn.BatchNorm2d(_out))
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
self.in_dim = nIn
|
||||
self.out_dim = nOut
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_flops(self, channels, check_range=True, divide=1):
|
||||
iC, oC = channels
|
||||
if check_range:
|
||||
assert (
|
||||
iC <= self.conv.in_channels and oC <= self.conv.out_channels
|
||||
), "{:} vs {:} | {:} vs {:}".format(
|
||||
iC, self.conv.in_channels, oC, self.conv.out_channels
|
||||
)
|
||||
assert (
|
||||
isinstance(self.InShape, tuple) and len(self.InShape) == 2
|
||||
), "invalid in-shape : {:}".format(self.InShape)
|
||||
assert (
|
||||
isinstance(self.OutShape, tuple) and len(self.OutShape) == 2
|
||||
), "invalid out-shape : {:}".format(self.OutShape)
|
||||
# conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
|
||||
conv_per_position_flops = (
|
||||
self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
|
||||
)
|
||||
all_positions = self.OutShape[0] * self.OutShape[1]
|
||||
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
|
||||
if self.conv.bias is not None:
|
||||
flops += all_positions / divide
|
||||
return flops
|
||||
|
||||
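# Editor's note (sketch): the estimate above is the standard dense-convolution count,
#     FLOPs ~= (k_h * k_w / groups) * H_out * W_out * iC * oC   (+ H_out * W_out if biased),
# e.g. a 3x3 conv mapping 16 -> 32 channels onto a 32x32 output map gives
# 9 * 32 * 32 * 16 * 32 = 4,718,592 multiply-accumulates (before the `divide` scaling).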
def get_range(self):
|
||||
return [self.choices]
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, index, prob = tuple_inputs
|
||||
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
|
||||
probability = torch.squeeze(probability)
|
||||
assert len(index) == 2, "invalid length : {:}".format(index)
|
||||
# compute expected flop
|
||||
# coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
|
||||
expected_outC = (self.choices_tensor * probability).sum()
|
||||
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
# convolutional layer
|
||||
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
|
||||
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
|
||||
# merge
|
||||
out_channel = max([x.size(1) for x in out_bns])
|
||||
outA = ChannelWiseInter(out_bns[0], out_channel)
|
||||
outB = ChannelWiseInter(out_bns[1], out_channel)
|
||||
out = outA * prob[0] + outB * prob[1]
|
||||
# out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
|
||||
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
return out, expected_outC, expected_flop
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.has_bn:
|
||||
out = self.BNs[-1](conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
if self.InShape is None:
|
||||
self.InShape = (inputs.size(-2), inputs.size(-1))
|
||||
self.OutShape = (out.size(-2), out.size(-1))
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
expansion = 1
|
||||
num_conv = 2
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_a = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_range(self):
|
||||
return self.conv_a.get_range() + self.conv_b.get_range()
|
||||
|
||||
def get_flops(self, channels):
|
||||
assert len(channels) == 3, "invalid channels : {:}".format(channels)
|
||||
flop_A = self.conv_a.get_flops([channels[0], channels[1]])
|
||||
flop_B = self.conv_b.get_flops([channels[1], channels[2]])
|
||||
if hasattr(self.downsample, "get_flops"):
|
||||
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
|
||||
else:
|
||||
flop_C = 0
|
||||
if (
|
||||
channels[0] != channels[-1] and self.downsample is None
|
||||
): # this short-cut will be added during the infer-train
|
||||
flop_C = (
|
||||
channels[0]
|
||||
* channels[-1]
|
||||
* self.conv_b.OutShape[0]
|
||||
* self.conv_b.OutShape[1]
|
||||
)
|
||||
return flop_A + flop_B + flop_C
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, indexes, probs = tuple_inputs
|
||||
assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2
|
||||
out_a, expected_inC_a, expected_flop_a = self.conv_a(
|
||||
(inputs, expected_inC, probability[0], indexes[0], probs[0])
|
||||
)
|
||||
out_b, expected_inC_b, expected_flop_b = self.conv_b(
|
||||
(out_a, expected_inC_a, probability[1], indexes[1], probs[1])
|
||||
)
|
||||
if self.downsample is not None:
|
||||
residual, _, expected_flop_c = self.downsample(
|
||||
(inputs, expected_inC, probability[1], indexes[1], probs[1])
|
||||
)
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out_b)
|
||||
return (
|
||||
nn.functional.relu(out, inplace=True),
|
||||
expected_inC_b,
|
||||
sum([expected_flop_a, expected_flop_b, expected_flop_c]),
|
||||
)
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = additive_func(residual, basicblock)
|
||||
return nn.functional.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
planes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
planes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes * self.expansion:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes * self.expansion
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_range(self):
|
||||
return (
|
||||
self.conv_1x1.get_range()
|
||||
+ self.conv_3x3.get_range()
|
||||
+ self.conv_1x4.get_range()
|
||||
)
|
||||
|
||||
def get_flops(self, channels):
|
||||
assert len(channels) == 4, "invalid channels : {:}".format(channels)
|
||||
flop_A = self.conv_1x1.get_flops([channels[0], channels[1]])
|
||||
flop_B = self.conv_3x3.get_flops([channels[1], channels[2]])
|
||||
flop_C = self.conv_1x4.get_flops([channels[2], channels[3]])
|
||||
if hasattr(self.downsample, "get_flops"):
|
||||
flop_D = self.downsample.get_flops([channels[0], channels[-1]])
|
||||
else:
|
||||
flop_D = 0
|
||||
if (
|
||||
channels[0] != channels[-1] and self.downsample is None
|
||||
): # this short-cut will be added during the infer-train
|
||||
flop_D = (
|
||||
channels[0]
|
||||
* channels[-1]
|
||||
* self.conv_1x4.OutShape[0]
|
||||
* self.conv_1x4.OutShape[1]
|
||||
)
|
||||
return flop_A + flop_B + flop_C + flop_D
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = additive_func(residual, bottleneck)
|
||||
return nn.functional.relu(out, inplace=True)
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, indexes, probs = tuple_inputs
|
||||
assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3
|
||||
out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1(
|
||||
(inputs, expected_inC, probability[0], indexes[0], probs[0])
|
||||
)
|
||||
out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3(
|
||||
(out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1])
|
||||
)
|
||||
out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4(
|
||||
(out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2])
|
||||
)
|
||||
if self.downsample is not None:
|
||||
residual, _, expected_flop_c = self.downsample(
|
||||
(inputs, expected_inC, probability[2], indexes[2], probs[2])
|
||||
)
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out_1x4)
|
||||
return (
|
||||
nn.functional.relu(out, inplace=True),
|
||||
expected_inC_1x4,
|
||||
sum(
|
||||
[
|
||||
expected_flop_1x1,
|
||||
expected_flop_3x3,
|
||||
expected_flop_1x4,
|
||||
expected_flop_c,
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SearchShapeCifarResNet(nn.Module):
|
||||
def __init__(self, block_name, depth, num_classes):
|
||||
super(SearchShapeCifarResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == "ResNetBasicblock":
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == "ResNetBottleneck":
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, "depth should be one of 164"
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
|
||||
self.message = (
|
||||
"SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format(
|
||||
depth, layer_blocks
|
||||
)
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.channels = [16]
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
]
|
||||
)
|
||||
self.InShape = None
|
||||
self.depth_info = OrderedDict()
|
||||
self.depth_at_i = OrderedDict()
|
||||
for stage in range(3):
|
||||
cur_block_choices = get_depth_choices(layer_blocks, False)
|
||||
assert (
|
||||
cur_block_choices[-1] == layer_blocks
|
||||
), "stage={:}, {:} vs {:}".format(stage, cur_block_choices, layer_blocks)
|
||||
self.message += (
|
||||
"\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format(
|
||||
stage, cur_block_choices, layer_blocks
|
||||
)
|
||||
)
|
||||
block_choices, xstart = [], len(self.layers)
|
||||
for iL in range(layer_blocks):
|
||||
iC = self.channels[-1]
|
||||
planes = 16 * (2 ** stage)
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iC, planes, stride)
|
||||
self.channels.append(module.out_dim)
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iC,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
# added for depth
|
||||
layer_index = len(self.layers) - 1
|
||||
if iL + 1 in cur_block_choices:
|
||||
block_choices.append(layer_index)
|
||||
if iL + 1 == layer_blocks:
|
||||
self.depth_info[layer_index] = {
|
||||
"choices": block_choices,
|
||||
"stage": stage,
|
||||
"xstart": xstart,
|
||||
}
|
||||
self.depth_info_list = []
|
||||
for xend, info in self.depth_info.items():
|
||||
self.depth_info_list.append((xend, info))
|
||||
xstart, xstage = info["xstart"], info["stage"]
|
||||
for ilayer in range(xstart, xend + 1):
|
||||
idx = bisect_right(info["choices"], ilayer - 1)
|
||||
self.depth_at_i[ilayer] = (xstage, idx)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(module.out_dim, num_classes)
|
||||
self.InShape = None
|
||||
self.tau = -1
|
||||
self.search_mode = "basic"
|
||||
# assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
|
||||
|
||||
# parameters for width
|
||||
self.Ranges = []
|
||||
self.layer2indexRange = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
start_index = len(self.Ranges)
|
||||
self.Ranges += layer.get_range()
|
||||
self.layer2indexRange.append((start_index, len(self.Ranges)))
|
||||
assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format(
|
||||
len(self.Ranges) + 1, depth
|
||||
)
|
||||
|
||||
self.register_parameter(
|
||||
"width_attentions",
|
||||
nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None))),
|
||||
)
|
||||
self.register_parameter(
|
||||
"depth_attentions",
|
||||
nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True))),
|
||||
)
|
||||
nn.init.normal_(self.width_attentions, 0, 0.01)
|
||||
nn.init.normal_(self.depth_attentions, 0, 0.01)
|
||||
self.apply(initialize_resnet)
|
||||
|
||||
def arch_parameters(self, LR=None):
|
||||
if LR is None:
|
||||
return [self.width_attentions, self.depth_attentions]
|
||||
else:
|
||||
return [
|
||||
{"params": self.width_attentions, "lr": LR},
|
||||
{"params": self.depth_attentions, "lr": LR},
|
||||
]
|
||||
|
||||
def base_parameters(self):
|
||||
return (
|
||||
list(self.layers.parameters())
|
||||
+ list(self.avgpool.parameters())
|
||||
+ list(self.classifier.parameters())
|
||||
)
|
||||
|
||||
def get_flop(self, mode, config_dict, extra_info):
|
||||
if config_dict is not None:
|
||||
config_dict = config_dict.copy()
|
||||
# select channels
|
||||
channels = [3]
|
||||
for i, weight in enumerate(self.width_attentions):
|
||||
if mode == "genotype":
|
||||
with torch.no_grad():
|
||||
probe = nn.functional.softmax(weight, dim=0)
|
||||
C = self.Ranges[i][torch.argmax(probe).item()]
|
||||
elif mode == "max":
|
||||
C = self.Ranges[i][-1]
|
||||
elif mode == "fix":
|
||||
C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
|
||||
elif mode == "random":
|
||||
assert isinstance(extra_info, float), "invalid extra_info : {:}".format(
|
||||
extra_info
|
||||
)
|
||||
with torch.no_grad():
|
||||
prob = nn.functional.softmax(weight, dim=0)
|
||||
approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
|
||||
for j in range(prob.size(0)):
|
||||
prob[j] = 1 / (
|
||||
abs(j - (approximate_C - self.Ranges[i][j])) + 0.2
|
||||
)
|
||||
C = self.Ranges[i][torch.multinomial(prob, 1, False).item()]
|
||||
else:
|
||||
raise ValueError("invalid mode : {:}".format(mode))
|
||||
channels.append(C)
|
||||
# select depth
|
||||
if mode == "genotype":
|
||||
with torch.no_grad():
|
||||
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
|
||||
choices = torch.argmax(depth_probs, dim=1).cpu().tolist()
|
||||
elif mode == "max" or mode == "fix":
|
||||
choices = [depth_probs.size(1) - 1 for _ in range(depth_probs.size(0))]
|
||||
elif mode == "random":
|
||||
with torch.no_grad():
|
||||
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
|
||||
choices = torch.multinomial(depth_probs, 1, False).cpu().tolist()
|
||||
else:
|
||||
raise ValueError("invalid mode : {:}".format(mode))
|
||||
selected_layers = []
|
||||
for choice, xvalue in zip(choices, self.depth_info_list):
|
||||
xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1
|
||||
selected_layers.append(xtemp)
|
||||
flop = 0
|
||||
for i, layer in enumerate(self.layers):
|
||||
s, e = self.layer2indexRange[i]
|
||||
xchl = tuple(channels[s : e + 1])
|
||||
if i in self.depth_at_i:
|
||||
xstagei, xatti = self.depth_at_i[i]
|
||||
if xatti <= choices[xstagei]: # leave this depth
|
||||
flop += layer.get_flops(xchl)
|
||||
else:
|
||||
flop += 0 # do not use this layer
|
||||
else:
|
||||
flop += layer.get_flops(xchl)
|
||||
# the last fc layer
|
||||
flop += channels[-1] * self.classifier.out_features
|
||||
if config_dict is None:
|
||||
return flop / 1e6
|
||||
else:
|
||||
config_dict["xchannels"] = channels
|
||||
config_dict["xblocks"] = selected_layers
|
||||
config_dict["super_type"] = "infer-shape"
|
||||
config_dict["estimated_FLOP"] = flop / 1e6
|
||||
return flop / 1e6, config_dict
|
||||
|
||||
def get_arch_info(self):
|
||||
string = (
|
||||
"for depth and width, there are {:} + {:} attention probabilities.".format(
|
||||
len(self.depth_attentions), len(self.width_attentions)
|
||||
)
|
||||
)
|
||||
string += "\n{:}".format(self.depth_info)
|
||||
discrepancy = []
|
||||
with torch.no_grad():
|
||||
for i, att in enumerate(self.depth_attentions):
|
||||
prob = nn.functional.softmax(att, dim=0)
|
||||
prob = prob.cpu()
|
||||
selc = prob.argmax().item()
|
||||
prob = prob.tolist()
|
||||
prob = ["{:.3f}".format(x) for x in prob]
|
||||
xstring = "{:03d}/{:03d}-th : {:}".format(
|
||||
i, len(self.depth_attentions), " ".join(prob)
|
||||
)
|
||||
logt = ["{:.4f}".format(x) for x in att.cpu().tolist()]
|
||||
xstring += " || {:17s}".format(" ".join(logt))
|
||||
prob = sorted([float(x) for x in prob])
|
||||
disc = prob[-1] - prob[-2]
|
||||
xstring += " || discrepancy={:.2f} || select={:}/{:}".format(
|
||||
disc, selc, len(prob)
|
||||
)
|
||||
discrepancy.append(disc)
|
||||
string += "\n{:}".format(xstring)
|
||||
string += "\n-----------------------------------------------"
|
||||
for i, att in enumerate(self.width_attentions):
|
||||
prob = nn.functional.softmax(att, dim=0)
|
||||
prob = prob.cpu()
|
||||
selc = prob.argmax().item()
|
||||
prob = prob.tolist()
|
||||
prob = ["{:.3f}".format(x) for x in prob]
|
||||
xstring = "{:03d}/{:03d}-th : {:}".format(
|
||||
i, len(self.width_attentions), " ".join(prob)
|
||||
)
|
||||
logt = ["{:.3f}".format(x) for x in att.cpu().tolist()]
|
||||
xstring += " || {:52s}".format(" ".join(logt))
|
||||
prob = sorted([float(x) for x in prob])
|
||||
disc = prob[-1] - prob[-2]
|
||||
xstring += " || dis={:.2f} || select={:}/{:}".format(
|
||||
disc, selc, len(prob)
|
||||
)
|
||||
discrepancy.append(disc)
|
||||
string += "\n{:}".format(xstring)
|
||||
return string, discrepancy
|
||||
|
||||
def set_tau(self, tau_max, tau_min, epoch_ratio):
|
||||
assert (
|
||||
epoch_ratio >= 0 and epoch_ratio <= 1
|
||||
), "invalid epoch-ratio : {:}".format(epoch_ratio)
|
||||
tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
|
||||
self.tau = tau
|
||||
|
||||
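# Editor's note (sketch): set_tau follows a cosine schedule from tau_max down to tau_min,
# e.g. with tau_max=10 and tau_min=0.1 it yields 10.0 at epoch_ratio=0, ~5.05 at 0.5, and 0.1 at 1.0.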
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, inputs):
|
||||
flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1)
|
||||
flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
|
||||
flop_depth_probs = torch.flip(
|
||||
torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1]
|
||||
)
|
||||
selected_widths, selected_width_probs = select2withP(
|
||||
self.width_attentions, self.tau
|
||||
)
|
||||
selected_depth_probs = select2withP(self.depth_attentions, self.tau, True)
|
||||
with torch.no_grad():
|
||||
selected_widths = selected_widths.cpu()
|
||||
|
||||
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
|
||||
feature_maps = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
selected_w_index = selected_widths[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
selected_w_probs = selected_width_probs[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
layer_prob = flop_width_probs[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
x, expected_inC, expected_flop = layer(
|
||||
(x, expected_inC, layer_prob, selected_w_index, selected_w_probs)
|
||||
)
|
||||
feature_maps.append(x)
|
||||
last_channel_idx += layer.num_conv
|
||||
if i in self.depth_info: # aggregate the information
|
||||
choices = self.depth_info[i]["choices"]
|
||||
xstagei = self.depth_info[i]["stage"]
|
||||
# print ('iL={:}, choices={:}, stage={:}, probs={:}'.format(i, choices, xstagei, selected_depth_probs[xstagei].cpu().tolist()))
|
||||
# for A, W in zip(choices, selected_depth_probs[xstagei]):
|
||||
# print('Size = {:}, W = {:}'.format(feature_maps[A].size(), W))
|
||||
possible_tensors = []
|
||||
max_C = max(feature_maps[A].size(1) for A in choices)
|
||||
for tempi, A in enumerate(choices):
|
||||
xtensor = ChannelWiseInter(feature_maps[A], max_C)
|
||||
# drop_ratio = 1-(tempi+1.0)/len(choices)
|
||||
# xtensor = drop_path(xtensor, drop_ratio)
|
||||
possible_tensors.append(xtensor)
|
||||
weighted_sum = sum(
|
||||
xtensor * W
|
||||
for xtensor, W in zip(
|
||||
possible_tensors, selected_depth_probs[xstagei]
|
||||
)
|
||||
)
|
||||
x = weighted_sum
|
||||
|
||||
if i in self.depth_at_i:
|
||||
xstagei, xatti = self.depth_at_i[i]
|
||||
x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop
|
||||
else:
|
||||
x_expected_flop = expected_flop
|
||||
flops.append(x_expected_flop)
|
||||
flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6))
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = linear_forward(features, self.classifier)
|
||||
return logits, torch.stack([sum(flops)])
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
if self.InShape is None:
|
||||
self.InShape = (inputs.size(-2), inputs.size(-1))
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
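Editor's note: a sketch of how the parameter split above is typically consumed, with assumed optimizer settings (and assuming the SoftSelect helpers imported at the top of this file behave as they are used here); network weights and architecture attentions get separate optimizers.

import torch

net = SearchShapeCifarResNet("ResNetBasicblock", 20, 10)
w_optimizer = torch.optim.SGD(net.base_parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
a_optimizer = torch.optim.Adam(net.arch_parameters(), lr=3e-4, weight_decay=1e-3)
net.set_tau(10, 0.1, epoch_ratio=0.0)  # anneal tau over training before calling search_forward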
@@ -0,0 +1,515 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import math, torch
|
||||
from collections import OrderedDict
|
||||
from bisect import bisect_right
|
||||
import torch.nn as nn
|
||||
from ..initialization import initialize_resnet
|
||||
from ..SharedUtils import additive_func
|
||||
from .SoftSelect import select2withP, ChannelWiseInter
|
||||
from .SoftSelect import linear_forward
|
||||
from .SoftSelect import get_width_choices
|
||||
|
||||
|
||||
def get_depth_choices(nDepth, return_num):
|
||||
if nDepth == 2:
|
||||
choices = (1, 2)
|
||||
elif nDepth == 3:
|
||||
choices = (1, 2, 3)
|
||||
elif nDepth > 3:
|
||||
choices = list(range(1, nDepth + 1, 2))
|
||||
if choices[-1] < nDepth:
|
||||
choices.append(nDepth)
|
||||
else:
|
||||
raise ValueError("invalid nDepth : {:}".format(nDepth))
|
||||
if return_num:
|
||||
return len(choices)
|
||||
else:
|
||||
return choices
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
num_conv = 1
|
||||
|
||||
def __init__(
|
||||
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
self.InShape = None
|
||||
self.OutShape = None
|
||||
self.choices = get_width_choices(nOut)
|
||||
self.register_buffer("choices_tensor", torch.Tensor(self.choices))
|
||||
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
if has_bn:
|
||||
self.bn = nn.BatchNorm2d(nOut)
|
||||
else:
|
||||
self.bn = None
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
else:
|
||||
self.relu = None
|
||||
self.in_dim = nIn
|
||||
self.out_dim = nOut
|
||||
|
||||
def get_flops(self, divide=1):
|
||||
iC, oC = self.in_dim, self.out_dim
|
||||
assert (
|
||||
iC <= self.conv.in_channels and oC <= self.conv.out_channels
|
||||
), "{:} vs {:} | {:} vs {:}".format(
|
||||
iC, self.conv.in_channels, oC, self.conv.out_channels
|
||||
)
|
||||
assert (
|
||||
isinstance(self.InShape, tuple) and len(self.InShape) == 2
|
||||
), "invalid in-shape : {:}".format(self.InShape)
|
||||
assert (
|
||||
isinstance(self.OutShape, tuple) and len(self.OutShape) == 2
|
||||
), "invalid out-shape : {:}".format(self.OutShape)
|
||||
# conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
|
||||
conv_per_position_flops = (
|
||||
self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
|
||||
)
|
||||
all_positions = self.OutShape[0] * self.OutShape[1]
|
||||
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
|
||||
if self.conv.bias is not None:
|
||||
flops += all_positions / divide
|
||||
return flops
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.bn:
|
||||
out = self.bn(conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
if self.InShape is None:
|
||||
self.InShape = (inputs.size(-2), inputs.size(-1))
|
||||
self.OutShape = (out.size(-2), out.size(-1))
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
expansion = 1
|
||||
num_conv = 2
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_a = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_flops(self, divide=1):
|
||||
flop_A = self.conv_a.get_flops(divide)
|
||||
flop_B = self.conv_b.get_flops(divide)
|
||||
if hasattr(self.downsample, "get_flops"):
|
||||
flop_C = self.downsample.get_flops(divide)
|
||||
else:
|
||||
flop_C = 0
|
||||
return flop_A + flop_B + flop_C
|
||||
|
||||
def forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = additive_func(residual, basicblock)
|
||||
return nn.functional.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
planes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
planes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes * self.expansion:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes * self.expansion
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_range(self):
|
||||
return (
|
||||
self.conv_1x1.get_range()
|
||||
+ self.conv_3x3.get_range()
|
||||
+ self.conv_1x4.get_range()
|
||||
)
|
||||
|
||||
def get_flops(self, divide):
|
||||
flop_A = self.conv_1x1.get_flops(divide)
|
||||
flop_B = self.conv_3x3.get_flops(divide)
|
||||
flop_C = self.conv_1x4.get_flops(divide)
|
||||
if hasattr(self.downsample, "get_flops"):
|
||||
flop_D = self.downsample.get_flops(divide)
|
||||
else:
|
||||
flop_D = 0
|
||||
return flop_A + flop_B + flop_C + flop_D
|
||||
|
||||
def forward(self, inputs):
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = additive_func(residual, bottleneck)
|
||||
return nn.functional.relu(out, inplace=True)
|
||||
|
||||
|
||||
class SearchDepthCifarResNet(nn.Module):
|
||||
def __init__(self, block_name, depth, num_classes):
|
||||
super(SearchDepthCifarResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == "ResNetBasicblock":
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == "ResNetBottleneck":
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, "depth should be one of 164"
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
|
||||
self.message = (
|
||||
"SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format(
|
||||
depth, layer_blocks
|
||||
)
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.channels = [16]
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
]
|
||||
)
|
||||
self.InShape = None
|
||||
self.depth_info = OrderedDict()
|
||||
self.depth_at_i = OrderedDict()
|
||||
for stage in range(3):
|
||||
cur_block_choices = get_depth_choices(layer_blocks, False)
|
||||
assert (
|
||||
cur_block_choices[-1] == layer_blocks
|
||||
), "stage={:}, {:} vs {:}".format(stage, cur_block_choices, layer_blocks)
|
||||
self.message += (
|
||||
"\nstage={:} ::: depth-block-choices={:} for {:} blocks.".format(
|
||||
stage, cur_block_choices, layer_blocks
|
||||
)
|
||||
)
|
||||
block_choices, xstart = [], len(self.layers)
|
||||
for iL in range(layer_blocks):
|
||||
iC = self.channels[-1]
|
||||
planes = 16 * (2 ** stage)
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iC, planes, stride)
|
||||
self.channels.append(module.out_dim)
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iC,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
# added for depth
|
||||
layer_index = len(self.layers) - 1
|
||||
if iL + 1 in cur_block_choices:
|
||||
block_choices.append(layer_index)
|
||||
if iL + 1 == layer_blocks:
|
||||
self.depth_info[layer_index] = {
|
||||
"choices": block_choices,
|
||||
"stage": stage,
|
||||
"xstart": xstart,
|
||||
}
|
||||
self.depth_info_list = []
|
||||
for xend, info in self.depth_info.items():
|
||||
self.depth_info_list.append((xend, info))
|
||||
xstart, xstage = info["xstart"], info["stage"]
|
||||
for ilayer in range(xstart, xend + 1):
|
||||
idx = bisect_right(info["choices"], ilayer - 1)
|
||||
self.depth_at_i[ilayer] = (xstage, idx)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(module.out_dim, num_classes)
|
||||
self.InShape = None
|
||||
self.tau = -1
|
||||
self.search_mode = "basic"
|
||||
# assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
|
||||
|
||||
self.register_parameter(
|
||||
"depth_attentions",
|
||||
nn.Parameter(torch.Tensor(3, get_depth_choices(layer_blocks, True))),
|
||||
)
|
||||
nn.init.normal_(self.depth_attentions, 0, 0.01)
|
||||
self.apply(initialize_resnet)
|
||||
|
||||
def arch_parameters(self):
|
||||
return [self.depth_attentions]
|
||||
|
||||
def base_parameters(self):
|
||||
return (
|
||||
list(self.layers.parameters())
|
||||
+ list(self.avgpool.parameters())
|
||||
+ list(self.classifier.parameters())
|
||||
)
|
||||
|
||||
def get_flop(self, mode, config_dict, extra_info):
|
||||
if config_dict is not None:
|
||||
config_dict = config_dict.copy()
|
||||
# select depth
|
||||
if mode == "genotype":
|
||||
with torch.no_grad():
|
||||
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
|
||||
choices = torch.argmax(depth_probs, dim=1).cpu().tolist()
|
||||
elif mode == "max":
|
||||
choices = [depth_probs.size(1) - 1 for _ in range(depth_probs.size(0))]
|
||||
elif mode == "random":
|
||||
with torch.no_grad():
|
||||
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
|
||||
choices = torch.multinomial(depth_probs, 1, False).cpu().tolist()
|
||||
else:
|
||||
raise ValueError("invalid mode : {:}".format(mode))
|
||||
selected_layers = []
|
||||
for choice, xvalue in zip(choices, self.depth_info_list):
|
||||
xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1
|
||||
selected_layers.append(xtemp)
|
||||
flop = 0
|
||||
for i, layer in enumerate(self.layers):
|
||||
if i in self.depth_at_i:
|
||||
xstagei, xatti = self.depth_at_i[i]
|
||||
if xatti <= choices[xstagei]: # leave this depth
|
||||
flop += layer.get_flops()
|
||||
else:
|
||||
flop += 0 # do not use this layer
|
||||
else:
|
||||
flop += layer.get_flops()
|
||||
# the last fc layer
|
||||
flop += self.classifier.in_features * self.classifier.out_features
|
||||
if config_dict is None:
|
||||
return flop / 1e6
|
||||
else:
|
||||
config_dict["xblocks"] = selected_layers
|
||||
config_dict["super_type"] = "infer-depth"
|
||||
config_dict["estimated_FLOP"] = flop / 1e6
|
||||
return flop / 1e6, config_dict
|
||||
|
||||
def get_arch_info(self):
|
||||
string = "for depth, there are {:} attention probabilities.".format(
|
||||
len(self.depth_attentions)
|
||||
)
|
||||
string += "\n{:}".format(self.depth_info)
|
||||
discrepancy = []
|
||||
with torch.no_grad():
|
||||
for i, att in enumerate(self.depth_attentions):
|
||||
prob = nn.functional.softmax(att, dim=0)
|
||||
prob = prob.cpu()
|
||||
selc = prob.argmax().item()
|
||||
prob = prob.tolist()
|
||||
prob = ["{:.3f}".format(x) for x in prob]
|
||||
xstring = "{:03d}/{:03d}-th : {:}".format(
|
||||
i, len(self.depth_attentions), " ".join(prob)
|
||||
)
|
||||
logt = ["{:.4f}".format(x) for x in att.cpu().tolist()]
|
||||
xstring += " || {:17s}".format(" ".join(logt))
|
||||
prob = sorted([float(x) for x in prob])
|
||||
disc = prob[-1] - prob[-2]
|
||||
xstring += " || discrepancy={:.2f} || select={:}/{:}".format(
|
||||
disc, selc, len(prob)
|
||||
)
|
||||
discrepancy.append(disc)
|
||||
string += "\n{:}".format(xstring)
|
||||
return string, discrepancy
|
||||
|
||||
def set_tau(self, tau_max, tau_min, epoch_ratio):
|
||||
assert (
|
||||
epoch_ratio >= 0 and epoch_ratio <= 1
|
||||
), "invalid epoch-ratio : {:}".format(epoch_ratio)
|
||||
tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
|
||||
self.tau = tau
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, inputs):
|
||||
flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
|
||||
flop_depth_probs = torch.flip(
|
||||
torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1]
|
||||
)
|
||||
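# Editor's note (sketch): flip + cumsum + flip converts the per-choice softmax into
# "keep at least this many blocks" probabilities, e.g. [0.2, 0.3, 0.5] -> [1.0, 0.8, 0.5];
# these reversed-cumulative values weight each layer's expected FLOP further below.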
selected_depth_probs = select2withP(self.depth_attentions, self.tau, True)
|
||||
|
||||
x, flops = inputs, []
|
||||
feature_maps = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
layer_i = layer(x)
|
||||
feature_maps.append(layer_i)
|
||||
if i in self.depth_info: # aggregate the information
|
||||
choices = self.depth_info[i]["choices"]
|
||||
xstagei = self.depth_info[i]["stage"]
|
||||
possible_tensors = []
|
||||
for tempi, A in enumerate(choices):
|
||||
xtensor = feature_maps[A]
|
||||
possible_tensors.append(xtensor)
|
||||
weighted_sum = sum(
|
||||
xtensor * W
|
||||
for xtensor, W in zip(
|
||||
possible_tensors, selected_depth_probs[xstagei]
|
||||
)
|
||||
)
|
||||
x = weighted_sum
|
||||
else:
|
||||
x = layer_i
|
||||
|
||||
if i in self.depth_at_i:
|
||||
xstagei, xatti = self.depth_at_i[i]
|
||||
# print ('layer-{:03d}, stage={:}, att={:}, prob={:}, flop={:}'.format(i, xstagei, xatti, flop_depth_probs[xstagei, xatti].item(), layer.get_flops(1e6)))
|
||||
x_expected_flop = flop_depth_probs[xstagei, xatti] * layer.get_flops(
|
||||
1e6
|
||||
)
|
||||
else:
|
||||
x_expected_flop = layer.get_flops(1e6)
|
||||
flops.append(x_expected_flop)
|
||||
flops.append(
|
||||
(self.classifier.in_features * self.classifier.out_features * 1.0 / 1e6)
|
||||
)
|
||||
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = linear_forward(features, self.classifier)
|
||||
return logits, torch.stack([sum(flops)])
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
if self.InShape is None:
|
||||
self.InShape = (inputs.size(-2), inputs.size(-1))
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
@@ -0,0 +1,619 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import math, torch
|
||||
import torch.nn as nn
|
||||
from ..initialization import initialize_resnet
|
||||
from ..SharedUtils import additive_func
|
||||
from .SoftSelect import select2withP, ChannelWiseInter
|
||||
from .SoftSelect import linear_forward
|
||||
from .SoftSelect import get_width_choices as get_choices
|
||||
|
||||
|
||||
def conv_forward(inputs, conv, choices):
|
||||
iC = conv.in_channels
|
||||
fill_size = list(inputs.size())
|
||||
fill_size[1] = iC - fill_size[1]
|
||||
filled = torch.zeros(fill_size, device=inputs.device)
|
||||
xinputs = torch.cat((inputs, filled), dim=1)
|
||||
outputs = conv(xinputs)
|
||||
selecteds = [outputs[:, :oC] for oC in choices]
|
||||
return selecteds
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
num_conv = 1
|
||||
|
||||
def __init__(
|
||||
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
self.InShape = None
|
||||
self.OutShape = None
|
||||
self.choices = get_choices(nOut)
|
||||
self.register_buffer("choices_tensor", torch.Tensor(self.choices))
|
||||
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
# if has_bn : self.bn = nn.BatchNorm2d(nOut)
|
||||
# else : self.bn = None
|
||||
self.has_bn = has_bn
|
||||
self.BNs = nn.ModuleList()
|
||||
for i, _out in enumerate(self.choices):
|
||||
self.BNs.append(nn.BatchNorm2d(_out))
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
self.in_dim = nIn
|
||||
self.out_dim = nOut
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_flops(self, channels, check_range=True, divide=1):
|
||||
iC, oC = channels
|
||||
if check_range:
|
||||
assert (
|
||||
iC <= self.conv.in_channels and oC <= self.conv.out_channels
|
||||
), "{:} vs {:} | {:} vs {:}".format(
|
||||
iC, self.conv.in_channels, oC, self.conv.out_channels
|
||||
)
|
||||
assert (
|
||||
isinstance(self.InShape, tuple) and len(self.InShape) == 2
|
||||
), "invalid in-shape : {:}".format(self.InShape)
|
||||
assert (
|
||||
isinstance(self.OutShape, tuple) and len(self.OutShape) == 2
|
||||
), "invalid out-shape : {:}".format(self.OutShape)
|
||||
# conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
|
||||
conv_per_position_flops = (
|
||||
self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
|
||||
)
|
||||
all_positions = self.OutShape[0] * self.OutShape[1]
|
||||
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
|
||||
if self.conv.bias is not None:
|
||||
flops += all_positions / divide
|
||||
return flops
|
||||
|
||||
def get_range(self):
|
||||
return [self.choices]
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, index, prob = tuple_inputs
|
||||
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
|
||||
probability = torch.squeeze(probability)
|
||||
assert len(index) == 2, "invalid length : {:}".format(index)
|
||||
# compute expected flop
|
||||
# coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
|
||||
expected_outC = (self.choices_tensor * probability).sum()
|
||||
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
# convolutional layer
|
||||
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
|
||||
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
|
||||
# merge
|
||||
out_channel = max([x.size(1) for x in out_bns])
|
||||
outA = ChannelWiseInter(out_bns[0], out_channel)
|
||||
outB = ChannelWiseInter(out_bns[1], out_channel)
|
||||
out = outA * prob[0] + outB * prob[1]
|
||||
# out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
|
||||
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
return out, expected_outC, expected_flop
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.has_bn:
|
||||
out = self.BNs[-1](conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
if self.InShape is None:
|
||||
self.InShape = (inputs.size(-2), inputs.size(-1))
|
||||
self.OutShape = (out.size(-2), out.size(-1))
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
expansion = 1
|
||||
num_conv = 2
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_a = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_range(self):
|
||||
return self.conv_a.get_range() + self.conv_b.get_range()
|
||||
|
||||
def get_flops(self, channels):
|
||||
assert len(channels) == 3, "invalid channels : {:}".format(channels)
|
||||
flop_A = self.conv_a.get_flops([channels[0], channels[1]])
|
||||
flop_B = self.conv_b.get_flops([channels[1], channels[2]])
|
||||
if hasattr(self.downsample, "get_flops"):
|
||||
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
|
||||
else:
|
||||
flop_C = 0
|
||||
if (
|
||||
channels[0] != channels[-1] and self.downsample is None
|
||||
): # this short-cut will be added during the infer-train
|
||||
flop_C = (
|
||||
channels[0]
|
||||
* channels[-1]
|
||||
* self.conv_b.OutShape[0]
|
||||
* self.conv_b.OutShape[1]
|
||||
)
|
||||
return flop_A + flop_B + flop_C
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, indexes, probs = tuple_inputs
|
||||
assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2
|
||||
out_a, expected_inC_a, expected_flop_a = self.conv_a(
|
||||
(inputs, expected_inC, probability[0], indexes[0], probs[0])
|
||||
)
|
||||
out_b, expected_inC_b, expected_flop_b = self.conv_b(
|
||||
(out_a, expected_inC_a, probability[1], indexes[1], probs[1])
|
||||
)
|
||||
if self.downsample is not None:
|
||||
residual, _, expected_flop_c = self.downsample(
|
||||
(inputs, expected_inC, probability[1], indexes[1], probs[1])
|
||||
)
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out_b)
|
||||
return (
|
||||
nn.functional.relu(out, inplace=True),
|
||||
expected_inC_b,
|
||||
sum([expected_flop_a, expected_flop_b, expected_flop_c]),
|
||||
)
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = additive_func(residual, basicblock)
|
||||
return nn.functional.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
planes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
planes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes * self.expansion:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes * self.expansion
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_range(self):
|
||||
return (
|
||||
self.conv_1x1.get_range()
|
||||
+ self.conv_3x3.get_range()
|
||||
+ self.conv_1x4.get_range()
|
||||
)
|
||||
|
||||
def get_flops(self, channels):
|
||||
assert len(channels) == 4, "invalid channels : {:}".format(channels)
|
||||
flop_A = self.conv_1x1.get_flops([channels[0], channels[1]])
|
||||
flop_B = self.conv_3x3.get_flops([channels[1], channels[2]])
|
||||
flop_C = self.conv_1x4.get_flops([channels[2], channels[3]])
|
||||
if hasattr(self.downsample, "get_flops"):
|
||||
flop_D = self.downsample.get_flops([channels[0], channels[-1]])
|
||||
else:
|
||||
flop_D = 0
|
||||
if (
|
||||
channels[0] != channels[-1] and self.downsample is None
|
||||
): # this short-cut will be added during the infer-train
|
||||
flop_D = (
|
||||
channels[0]
|
||||
* channels[-1]
|
||||
* self.conv_1x4.OutShape[0]
|
||||
* self.conv_1x4.OutShape[1]
|
||||
)
|
||||
return flop_A + flop_B + flop_C + flop_D
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = additive_func(residual, bottleneck)
|
||||
return nn.functional.relu(out, inplace=True)
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, indexes, probs = tuple_inputs
|
||||
assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3
|
||||
out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1(
|
||||
(inputs, expected_inC, probability[0], indexes[0], probs[0])
|
||||
)
|
||||
out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3(
|
||||
(out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1])
|
||||
)
|
||||
out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4(
|
||||
(out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2])
|
||||
)
|
||||
if self.downsample is not None:
|
||||
residual, _, expected_flop_c = self.downsample(
|
||||
(inputs, expected_inC, probability[2], indexes[2], probs[2])
|
||||
)
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out_1x4)
|
||||
return (
|
||||
nn.functional.relu(out, inplace=True),
|
||||
expected_inC_1x4,
|
||||
sum(
|
||||
[
|
||||
expected_flop_1x1,
|
||||
expected_flop_3x3,
|
||||
expected_flop_1x4,
|
||||
expected_flop_c,
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SearchWidthCifarResNet(nn.Module):
|
||||
def __init__(self, block_name, depth, num_classes):
|
||||
super(SearchWidthCifarResNet, self).__init__()
|
||||
|
||||
# Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
|
||||
if block_name == "ResNetBasicblock":
|
||||
block = ResNetBasicblock
|
||||
assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110"
|
||||
layer_blocks = (depth - 2) // 6
|
||||
elif block_name == "ResNetBottleneck":
|
||||
block = ResNetBottleneck
|
||||
assert (depth - 2) % 9 == 0, "depth should be one of 164"
|
||||
layer_blocks = (depth - 2) // 9
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
|
||||
self.message = (
|
||||
"SearchWidthCifarResNet : Depth : {:} , Layers for each block : {:}".format(
|
||||
depth, layer_blocks
|
||||
)
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.channels = [16]
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
]
|
||||
)
|
||||
self.InShape = None
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
iC = self.channels[-1]
|
||||
planes = 16 * (2 ** stage)
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iC, planes, stride)
|
||||
self.channels.append(module.out_dim)
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iC,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(module.out_dim, num_classes)
|
||||
self.InShape = None
|
||||
self.tau = -1
|
||||
self.search_mode = "basic"
|
||||
# assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
|
||||
|
||||
# parameters for width
|
||||
self.Ranges = []
|
||||
self.layer2indexRange = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
start_index = len(self.Ranges)
|
||||
self.Ranges += layer.get_range()
|
||||
self.layer2indexRange.append((start_index, len(self.Ranges)))
|
||||
assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format(
|
||||
len(self.Ranges) + 1, depth
|
||||
)
|
||||
|
||||
self.register_parameter(
|
||||
"width_attentions",
|
||||
nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))),
|
||||
)
|
||||
nn.init.normal_(self.width_attentions, 0, 0.01)
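# width_attentions holds one row of architecture logits per searchable convolution
# (len(self.Ranges) rows) and one column per candidate width choice.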
|
||||
self.apply(initialize_resnet)
|
||||
|
||||
def arch_parameters(self):
|
||||
return [self.width_attentions]
|
||||
|
||||
def base_parameters(self):
|
||||
return (
|
||||
list(self.layers.parameters())
|
||||
+ list(self.avgpool.parameters())
|
||||
+ list(self.classifier.parameters())
|
||||
)
|
||||
|
||||
def get_flop(self, mode, config_dict, extra_info):
|
||||
if config_dict is not None:
|
||||
config_dict = config_dict.copy()
|
||||
# weights = [F.softmax(x, dim=0) for x in self.width_attentions]
|
||||
channels = [3]
|
||||
for i, weight in enumerate(self.width_attentions):
|
||||
if mode == "genotype":
|
||||
with torch.no_grad():
|
||||
probe = nn.functional.softmax(weight, dim=0)
|
||||
C = self.Ranges[i][torch.argmax(probe).item()]
|
||||
elif mode == "max":
|
||||
C = self.Ranges[i][-1]
|
||||
elif mode == "fix":
|
||||
C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
|
||||
elif mode == "random":
|
||||
assert isinstance(extra_info, float), "invalid extra_info : {:}".format(
|
||||
extra_info
|
||||
)
|
||||
with torch.no_grad():
|
||||
prob = nn.functional.softmax(weight, dim=0)
|
||||
approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
|
||||
for j in range(prob.size(0)):
|
||||
prob[j] = 1 / (
|
||||
abs(j - (approximate_C - self.Ranges[i][j])) + 0.2
|
||||
)
|
||||
C = self.Ranges[i][torch.multinomial(prob, 1, False).item()]
|
||||
else:
|
||||
raise ValueError("invalid mode : {:}".format(mode))
|
||||
channels.append(C)
|
||||
flop = 0
|
||||
for i, layer in enumerate(self.layers):
|
||||
s, e = self.layer2indexRange[i]
|
||||
xchl = tuple(channels[s : e + 1])
|
||||
flop += layer.get_flops(xchl)
|
||||
# the last fc layer
|
||||
flop += channels[-1] * self.classifier.out_features
|
||||
if config_dict is None:
|
||||
return flop / 1e6
|
||||
else:
|
||||
config_dict["xchannels"] = channels
|
||||
config_dict["super_type"] = "infer-width"
|
||||
config_dict["estimated_FLOP"] = flop / 1e6
|
||||
return flop / 1e6, config_dict
|
||||
|
||||
def get_arch_info(self):
|
||||
string = "for width, there are {:} attention probabilities.".format(
|
||||
len(self.width_attentions)
|
||||
)
|
||||
discrepancy = []
|
||||
with torch.no_grad():
|
||||
for i, att in enumerate(self.width_attentions):
|
||||
prob = nn.functional.softmax(att, dim=0)
|
||||
prob = prob.cpu()
|
||||
selc = prob.argmax().item()
|
||||
prob = prob.tolist()
|
||||
prob = ["{:.3f}".format(x) for x in prob]
|
||||
xstring = "{:03d}/{:03d}-th : {:}".format(
|
||||
i, len(self.width_attentions), " ".join(prob)
|
||||
)
|
||||
logt = ["{:.3f}".format(x) for x in att.cpu().tolist()]
|
||||
xstring += " || {:52s}".format(" ".join(logt))
|
||||
prob = sorted([float(x) for x in prob])
|
||||
disc = prob[-1] - prob[-2]
|
||||
xstring += " || dis={:.2f} || select={:}/{:}".format(
|
||||
disc, selc, len(prob)
|
||||
)
|
||||
discrepancy.append(disc)
|
||||
string += "\n{:}".format(xstring)
|
||||
return string, discrepancy
|
||||
|
||||
def set_tau(self, tau_max, tau_min, epoch_ratio):
|
||||
assert (
|
||||
epoch_ratio >= 0 and epoch_ratio <= 1
|
||||
), "invalid epoch-ratio : {:}".format(epoch_ratio)
|
||||
tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
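# cosine schedule for the Gumbel-softmax temperature: tau equals tau_max at epoch_ratio=0 and decays to tau_min at epoch_ratio=1.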
|
||||
self.tau = tau
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, inputs):
|
||||
flop_probs = nn.functional.softmax(self.width_attentions, dim=1)
|
||||
selected_widths, selected_probs = select2withP(self.width_attentions, self.tau)
|
||||
with torch.no_grad():
|
||||
selected_widths = selected_widths.cpu()
|
||||
|
||||
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
|
||||
for i, layer in enumerate(self.layers):
|
||||
selected_w_index = selected_widths[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
selected_w_probs = selected_probs[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
layer_prob = flop_probs[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
x, expected_inC, expected_flop = layer(
|
||||
(x, expected_inC, layer_prob, selected_w_index, selected_w_probs)
|
||||
)
|
||||
last_channel_idx += layer.num_conv
|
||||
flops.append(expected_flop)
|
||||
flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6))
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = linear_forward(features, self.classifier)
|
||||
return logits, torch.stack([sum(flops)])
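# in search mode the forward pass returns the logits together with the total expected cost in MFLOPs,
# so a search objective can trade accuracy off against the expected FLOP budget.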
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
if self.InShape is None:
|
||||
self.InShape = (inputs.size(-2), inputs.size(-1))
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
@@ -0,0 +1,766 @@
|
||||
import math, torch
|
||||
from collections import OrderedDict
|
||||
from bisect import bisect_right
|
||||
import torch.nn as nn
|
||||
from ..initialization import initialize_resnet
|
||||
from ..SharedUtils import additive_func
|
||||
from .SoftSelect import select2withP, ChannelWiseInter
|
||||
from .SoftSelect import linear_forward
|
||||
from .SoftSelect import get_width_choices
|
||||
|
||||
|
||||
def get_depth_choices(layers):
|
||||
min_depth = min(layers)
|
||||
info = {"num": min_depth}
|
||||
for i, depth in enumerate(layers):
|
||||
choices = []
|
||||
for j in range(1, min_depth + 1):
|
||||
choices.append(int(float(depth) * j / min_depth))
|
||||
info[i] = choices
|
||||
return info
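# e.g. get_depth_choices([3, 4, 6, 3]) -> {'num': 3, 0: [1, 2, 3], 1: [1, 2, 4], 2: [2, 4, 6], 3: [1, 2, 3]}:
# every stage gets min(layers) candidate depths, scaled proportionally to its full depth (rounded down).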
|
||||
|
||||
|
||||
def conv_forward(inputs, conv, choices):
|
||||
iC = conv.in_channels
|
||||
fill_size = list(inputs.size())
|
||||
fill_size[1] = iC - fill_size[1]
|
||||
filled = torch.zeros(fill_size, device=inputs.device)
|
||||
xinputs = torch.cat((inputs, filled), dim=1)
|
||||
outputs = conv(xinputs)
|
||||
selecteds = [outputs[:, :oC] for oC in choices]
|
||||
return selecteds
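# the input is zero-padded up to conv.in_channels (zeros contribute nothing to the convolution),
# the full convolution runs once, and each candidate width keeps only its first oC output channels.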
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
num_conv = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nIn,
|
||||
nOut,
|
||||
kernel,
|
||||
stride,
|
||||
padding,
|
||||
bias,
|
||||
has_avg,
|
||||
has_bn,
|
||||
has_relu,
|
||||
last_max_pool=False,
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
self.InShape = None
|
||||
self.OutShape = None
|
||||
self.choices = get_width_choices(nOut)
|
||||
self.register_buffer("choices_tensor", torch.Tensor(self.choices))
|
||||
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
# if has_bn : self.bn = nn.BatchNorm2d(nOut)
|
||||
# else : self.bn = None
|
||||
self.has_bn = has_bn
|
||||
self.BNs = nn.ModuleList()
|
||||
for i, _out in enumerate(self.choices):
|
||||
self.BNs.append(nn.BatchNorm2d(_out))
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
|
||||
if last_max_pool:
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
else:
|
||||
self.maxpool = None
|
||||
self.in_dim = nIn
|
||||
self.out_dim = nOut
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_flops(self, channels, check_range=True, divide=1):
|
||||
iC, oC = channels
|
||||
if check_range:
|
||||
assert (
|
||||
iC <= self.conv.in_channels and oC <= self.conv.out_channels
|
||||
), "{:} vs {:} | {:} vs {:}".format(
|
||||
iC, self.conv.in_channels, oC, self.conv.out_channels
|
||||
)
|
||||
assert (
|
||||
isinstance(self.InShape, tuple) and len(self.InShape) == 2
|
||||
), "invalid in-shape : {:}".format(self.InShape)
|
||||
assert (
|
||||
isinstance(self.OutShape, tuple) and len(self.OutShape) == 2
|
||||
), "invalid out-shape : {:}".format(self.OutShape)
|
||||
# conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
|
||||
conv_per_position_flops = (
|
||||
self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
|
||||
)
|
||||
all_positions = self.OutShape[0] * self.OutShape[1]
|
||||
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
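# i.e. FLOPs ~= (k_h * k_w / groups) * H_out * W_out * iC * oC, counting each multiply-add once;
# `divide` rescales the result (divide=1e6 gives MFLOPs).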
|
||||
if self.conv.bias is not None:
|
||||
flops += all_positions / divide
|
||||
return flops
|
||||
|
||||
def get_range(self):
|
||||
return [self.choices]
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, index, prob = tuple_inputs
|
||||
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
|
||||
probability = torch.squeeze(probability)
|
||||
assert len(index) == 2, "invalid length : {:}".format(index)
|
||||
# compute expected flop
|
||||
# coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
|
||||
expected_outC = (self.choices_tensor * probability).sum()
|
||||
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
# convolutional layer
|
||||
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
|
||||
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
|
||||
# merge
|
||||
out_channel = max([x.size(1) for x in out_bns])
|
||||
outA = ChannelWiseInter(out_bns[0], out_channel)
|
||||
outB = ChannelWiseInter(out_bns[1], out_channel)
|
||||
out = outA * prob[0] + outB * prob[1]
|
||||
# out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
|
||||
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
if self.maxpool:
|
||||
out = self.maxpool(out)
|
||||
return out, expected_outC, expected_flop
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.has_bn:
|
||||
out = self.BNs[-1](conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
if self.InShape is None:
|
||||
self.InShape = (inputs.size(-2), inputs.size(-1))
|
||||
self.OutShape = (out.size(-2), out.size(-1))
|
||||
if self.maxpool:
|
||||
out = self.maxpool(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResNetBasicblock(nn.Module):
|
||||
expansion = 1
|
||||
num_conv = 2
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBasicblock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_a = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_b = ConvBNReLU(
|
||||
planes, planes, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=False
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_range(self):
|
||||
return self.conv_a.get_range() + self.conv_b.get_range()
|
||||
|
||||
def get_flops(self, channels):
|
||||
assert len(channels) == 3, "invalid channels : {:}".format(channels)
|
||||
flop_A = self.conv_a.get_flops([channels[0], channels[1]])
|
||||
flop_B = self.conv_b.get_flops([channels[1], channels[2]])
|
||||
if hasattr(self.downsample, "get_flops"):
|
||||
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
|
||||
else:
|
||||
flop_C = 0
|
||||
if (
|
||||
channels[0] != channels[-1] and self.downsample is None
|
||||
): # this short-cut will be added during the infer-train
|
||||
flop_C = (
|
||||
channels[0]
|
||||
* channels[-1]
|
||||
* self.conv_b.OutShape[0]
|
||||
* self.conv_b.OutShape[1]
|
||||
)
|
||||
return flop_A + flop_B + flop_C
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, indexes, probs = tuple_inputs
|
||||
assert indexes.size(0) == 2 and probs.size(0) == 2 and probability.size(0) == 2
|
||||
# import pdb; pdb.set_trace()
|
||||
out_a, expected_inC_a, expected_flop_a = self.conv_a(
|
||||
(inputs, expected_inC, probability[0], indexes[0], probs[0])
|
||||
)
|
||||
out_b, expected_inC_b, expected_flop_b = self.conv_b(
|
||||
(out_a, expected_inC_a, probability[1], indexes[1], probs[1])
|
||||
)
|
||||
if self.downsample is not None:
|
||||
residual, _, expected_flop_c = self.downsample(
|
||||
(inputs, expected_inC, probability[1], indexes[1], probs[1])
|
||||
)
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out_b)
|
||||
return (
|
||||
nn.functional.relu(out, inplace=True),
|
||||
expected_inC_b,
|
||||
sum([expected_flop_a, expected_flop_b, expected_flop_c]),
|
||||
)
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
basicblock = self.conv_b(basicblock)
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = additive_func(residual, basicblock)
|
||||
return nn.functional.relu(out, inplace=True)
|
||||
|
||||
|
||||
class ResNetBottleneck(nn.Module):
|
||||
expansion = 4
|
||||
num_conv = 3
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(ResNetBottleneck, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv_1x1 = ConvBNReLU(
|
||||
inplanes, planes, 1, 1, 0, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
self.conv_3x3 = ConvBNReLU(
|
||||
planes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
self.conv_1x4 = ConvBNReLU(
|
||||
planes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes * self.expansion:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes * self.expansion,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes * self.expansion
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_range(self):
|
||||
return (
|
||||
self.conv_1x1.get_range()
|
||||
+ self.conv_3x3.get_range()
|
||||
+ self.conv_1x4.get_range()
|
||||
)
|
||||
|
||||
def get_flops(self, channels):
|
||||
assert len(channels) == 4, "invalid channels : {:}".format(channels)
|
||||
flop_A = self.conv_1x1.get_flops([channels[0], channels[1]])
|
||||
flop_B = self.conv_3x3.get_flops([channels[1], channels[2]])
|
||||
flop_C = self.conv_1x4.get_flops([channels[2], channels[3]])
|
||||
if hasattr(self.downsample, "get_flops"):
|
||||
flop_D = self.downsample.get_flops([channels[0], channels[-1]])
|
||||
else:
|
||||
flop_D = 0
|
||||
if (
|
||||
channels[0] != channels[-1] and self.downsample is None
|
||||
): # this short-cut will be added during the infer-train
|
||||
flop_D = (
|
||||
channels[0]
|
||||
* channels[-1]
|
||||
* self.conv_1x4.OutShape[0]
|
||||
* self.conv_1x4.OutShape[1]
|
||||
)
|
||||
return flop_A + flop_B + flop_C + flop_D
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
bottleneck = self.conv_1x1(inputs)
|
||||
bottleneck = self.conv_3x3(bottleneck)
|
||||
bottleneck = self.conv_1x4(bottleneck)
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = additive_func(residual, bottleneck)
|
||||
return nn.functional.relu(out, inplace=True)
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, indexes, probs = tuple_inputs
|
||||
assert indexes.size(0) == 3 and probs.size(0) == 3 and probability.size(0) == 3
|
||||
out_1x1, expected_inC_1x1, expected_flop_1x1 = self.conv_1x1(
|
||||
(inputs, expected_inC, probability[0], indexes[0], probs[0])
|
||||
)
|
||||
out_3x3, expected_inC_3x3, expected_flop_3x3 = self.conv_3x3(
|
||||
(out_1x1, expected_inC_1x1, probability[1], indexes[1], probs[1])
|
||||
)
|
||||
out_1x4, expected_inC_1x4, expected_flop_1x4 = self.conv_1x4(
|
||||
(out_3x3, expected_inC_3x3, probability[2], indexes[2], probs[2])
|
||||
)
|
||||
if self.downsample is not None:
|
||||
residual, _, expected_flop_c = self.downsample(
|
||||
(inputs, expected_inC, probability[2], indexes[2], probs[2])
|
||||
)
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out_1x4)
|
||||
return (
|
||||
nn.functional.relu(out, inplace=True),
|
||||
expected_inC_1x4,
|
||||
sum(
|
||||
[
|
||||
expected_flop_1x1,
|
||||
expected_flop_3x3,
|
||||
expected_flop_1x4,
|
||||
expected_flop_c,
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SearchShapeImagenetResNet(nn.Module):
|
||||
def __init__(self, block_name, layers, deep_stem, num_classes):
|
||||
super(SearchShapeImagenetResNet, self).__init__()
|
||||
|
||||
# Block type selects the residual block; `layers` gives the number of blocks per stage
|
||||
if block_name == "BasicBlock":
|
||||
block = ResNetBasicblock
|
||||
elif block_name == "Bottleneck":
|
||||
block = ResNetBottleneck
|
||||
else:
|
||||
raise ValueError("invalid block : {:}".format(block_name))
|
||||
|
||||
self.message = (
|
||||
"SearchShapeCifarResNet : Depth : {:} , Layers for each block : {:}".format(
|
||||
sum(layers) * block.num_conv, layers
|
||||
)
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
if not deep_stem:
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
3,
|
||||
64,
|
||||
7,
|
||||
2,
|
||||
3,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
last_max_pool=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
self.channels = [64]
|
||||
else:
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
3, 32, 3, 2, 1, False, has_avg=False, has_bn=True, has_relu=True
|
||||
),
|
||||
ConvBNReLU(
|
||||
32,
|
||||
64,
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
last_max_pool=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
self.channels = [32, 64]
|
||||
|
||||
meta_depth_info = get_depth_choices(layers)
|
||||
self.InShape = None
|
||||
self.depth_info = OrderedDict()
|
||||
self.depth_at_i = OrderedDict()
|
||||
for stage, layer_blocks in enumerate(layers):
|
||||
cur_block_choices = meta_depth_info[stage]
|
||||
assert (
|
||||
cur_block_choices[-1] == layer_blocks
|
||||
), "stage={:}, {:} vs {:}".format(stage, cur_block_choices, layer_blocks)
|
||||
block_choices, xstart = [], len(self.layers)
|
||||
for iL in range(layer_blocks):
|
||||
iC = self.channels[-1]
|
||||
planes = 64 * (2 ** stage)
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = block(iC, planes, stride)
|
||||
self.channels.append(module.out_dim)
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iC,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
# added for depth
|
||||
layer_index = len(self.layers) - 1
|
||||
if iL + 1 in cur_block_choices:
|
||||
block_choices.append(layer_index)
|
||||
if iL + 1 == layer_blocks:
|
||||
self.depth_info[layer_index] = {
|
||||
"choices": block_choices,
|
||||
"stage": stage,
|
||||
"xstart": xstart,
|
||||
}
|
||||
self.depth_info_list = []
|
||||
for xend, info in self.depth_info.items():
|
||||
self.depth_info_list.append((xend, info))
|
||||
xstart, xstage = info["xstart"], info["stage"]
|
||||
for ilayer in range(xstart, xend + 1):
|
||||
idx = bisect_right(info["choices"], ilayer - 1)
|
||||
self.depth_at_i[ilayer] = (xstage, idx)
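# depth_at_i maps every block index to (stage, smallest depth-choice index whose endpoint still covers
# this block); the block is active only when the selected depth choice is at least that index.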
|
||||
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.classifier = nn.Linear(module.out_dim, num_classes)
|
||||
self.InShape = None
|
||||
self.tau = -1
|
||||
self.search_mode = "basic"
|
||||
# assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
|
||||
|
||||
# parameters for width
|
||||
self.Ranges = []
|
||||
self.layer2indexRange = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
start_index = len(self.Ranges)
|
||||
self.Ranges += layer.get_range()
|
||||
self.layer2indexRange.append((start_index, len(self.Ranges)))
|
||||
|
||||
self.register_parameter(
|
||||
"width_attentions",
|
||||
nn.Parameter(torch.Tensor(len(self.Ranges), get_width_choices(None))),
|
||||
)
|
||||
self.register_parameter(
|
||||
"depth_attentions",
|
||||
nn.Parameter(torch.Tensor(len(layers), meta_depth_info["num"])),
|
||||
)
|
||||
nn.init.normal_(self.width_attentions, 0, 0.01)
|
||||
nn.init.normal_(self.depth_attentions, 0, 0.01)
|
||||
self.apply(initialize_resnet)
|
||||
|
||||
def arch_parameters(self, LR=None):
|
||||
if LR is None:
|
||||
return [self.width_attentions, self.depth_attentions]
|
||||
else:
|
||||
return [
|
||||
{"params": self.width_attentions, "lr": LR},
|
||||
{"params": self.depth_attentions, "lr": LR},
|
||||
]
|
||||
|
||||
def base_parameters(self):
|
||||
return (
|
||||
list(self.layers.parameters())
|
||||
+ list(self.avgpool.parameters())
|
||||
+ list(self.classifier.parameters())
|
||||
)
|
||||
|
||||
def get_flop(self, mode, config_dict, extra_info):
|
||||
if config_dict is not None:
|
||||
config_dict = config_dict.copy()
|
||||
# select channels
|
||||
channels = [3]
|
||||
for i, weight in enumerate(self.width_attentions):
|
||||
if mode == "genotype":
|
||||
with torch.no_grad():
|
||||
probe = nn.functional.softmax(weight, dim=0)
|
||||
C = self.Ranges[i][torch.argmax(probe).item()]
|
||||
else:
|
||||
raise ValueError("invalid mode : {:}".format(mode))
|
||||
channels.append(C)
|
||||
# select depth
|
||||
if mode == "genotype":
|
||||
with torch.no_grad():
|
||||
depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
|
||||
choices = torch.argmax(depth_probs, dim=1).cpu().tolist()
|
||||
else:
|
||||
raise ValueError("invalid mode : {:}".format(mode))
|
||||
selected_layers = []
|
||||
for choice, xvalue in zip(choices, self.depth_info_list):
|
||||
xtemp = xvalue[1]["choices"][choice] - xvalue[1]["xstart"] + 1
|
||||
selected_layers.append(xtemp)
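# convert the argmax depth choice (a global layer index) back into a per-stage block count,
# later exported as config_dict["xblocks"].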
|
||||
flop = 0
|
||||
for i, layer in enumerate(self.layers):
|
||||
s, e = self.layer2indexRange[i]
|
||||
xchl = tuple(channels[s : e + 1])
|
||||
if i in self.depth_at_i:
|
||||
xstagei, xatti = self.depth_at_i[i]
|
||||
if xatti <= choices[xstagei]:  # the selected depth keeps this layer
|
||||
flop += layer.get_flops(xchl)
|
||||
else:
|
||||
flop += 0 # do not use this layer
|
||||
else:
|
||||
flop += layer.get_flops(xchl)
|
||||
# the last fc layer
|
||||
flop += channels[-1] * self.classifier.out_features
|
||||
if config_dict is None:
|
||||
return flop / 1e6
|
||||
else:
|
||||
config_dict["xchannels"] = channels
|
||||
config_dict["xblocks"] = selected_layers
|
||||
config_dict["super_type"] = "infer-shape"
|
||||
config_dict["estimated_FLOP"] = flop / 1e6
|
||||
return flop / 1e6, config_dict
|
||||
|
||||
def get_arch_info(self):
|
||||
string = (
|
||||
"for depth and width, there are {:} + {:} attention probabilities.".format(
|
||||
len(self.depth_attentions), len(self.width_attentions)
|
||||
)
|
||||
)
|
||||
string += "\n{:}".format(self.depth_info)
|
||||
discrepancy = []
|
||||
with torch.no_grad():
|
||||
for i, att in enumerate(self.depth_attentions):
|
||||
prob = nn.functional.softmax(att, dim=0)
|
||||
prob = prob.cpu()
|
||||
selc = prob.argmax().item()
|
||||
prob = prob.tolist()
|
||||
prob = ["{:.3f}".format(x) for x in prob]
|
||||
xstring = "{:03d}/{:03d}-th : {:}".format(
|
||||
i, len(self.depth_attentions), " ".join(prob)
|
||||
)
|
||||
logt = ["{:.4f}".format(x) for x in att.cpu().tolist()]
|
||||
xstring += " || {:17s}".format(" ".join(logt))
|
||||
prob = sorted([float(x) for x in prob])
|
||||
disc = prob[-1] - prob[-2]
|
||||
xstring += " || discrepancy={:.2f} || select={:}/{:}".format(
|
||||
disc, selc, len(prob)
|
||||
)
|
||||
discrepancy.append(disc)
|
||||
string += "\n{:}".format(xstring)
|
||||
string += "\n-----------------------------------------------"
|
||||
for i, att in enumerate(self.width_attentions):
|
||||
prob = nn.functional.softmax(att, dim=0)
|
||||
prob = prob.cpu()
|
||||
selc = prob.argmax().item()
|
||||
prob = prob.tolist()
|
||||
prob = ["{:.3f}".format(x) for x in prob]
|
||||
xstring = "{:03d}/{:03d}-th : {:}".format(
|
||||
i, len(self.width_attentions), " ".join(prob)
|
||||
)
|
||||
logt = ["{:.3f}".format(x) for x in att.cpu().tolist()]
|
||||
xstring += " || {:52s}".format(" ".join(logt))
|
||||
prob = sorted([float(x) for x in prob])
|
||||
disc = prob[-1] - prob[-2]
|
||||
xstring += " || dis={:.2f} || select={:}/{:}".format(
|
||||
disc, selc, len(prob)
|
||||
)
|
||||
discrepancy.append(disc)
|
||||
string += "\n{:}".format(xstring)
|
||||
return string, discrepancy
|
||||
|
||||
def set_tau(self, tau_max, tau_min, epoch_ratio):
|
||||
assert (
|
||||
epoch_ratio >= 0 and epoch_ratio <= 1
|
||||
), "invalid epoch-ratio : {:}".format(epoch_ratio)
|
||||
tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
|
||||
self.tau = tau
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, inputs):
|
||||
flop_width_probs = nn.functional.softmax(self.width_attentions, dim=1)
|
||||
flop_depth_probs = nn.functional.softmax(self.depth_attentions, dim=1)
|
||||
flop_depth_probs = torch.flip(
|
||||
torch.cumsum(torch.flip(flop_depth_probs, [1]), 1), [1]
|
||||
)
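# the reversed cumulative sum turns per-choice depth probabilities into P(selected depth >= k), i.e. the
# probability that the k-th candidate endpoint and its blocks are kept; it weights the expected FLOPs below.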
|
||||
selected_widths, selected_width_probs = select2withP(
|
||||
self.width_attentions, self.tau
|
||||
)
|
||||
selected_depth_probs = select2withP(self.depth_attentions, self.tau, True)
|
||||
with torch.no_grad():
|
||||
selected_widths = selected_widths.cpu()
|
||||
|
||||
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
|
||||
feature_maps = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
selected_w_index = selected_widths[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
selected_w_probs = selected_width_probs[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
layer_prob = flop_width_probs[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
x, expected_inC, expected_flop = layer(
|
||||
(x, expected_inC, layer_prob, selected_w_index, selected_w_probs)
|
||||
)
|
||||
feature_maps.append(x)
|
||||
last_channel_idx += layer.num_conv
|
||||
if i in self.depth_info: # aggregate the information
|
||||
choices = self.depth_info[i]["choices"]
|
||||
xstagei = self.depth_info[i]["stage"]
|
||||
# print ('iL={:}, choices={:}, stage={:}, probs={:}'.format(i, choices, xstagei, selected_depth_probs[xstagei].cpu().tolist()))
|
||||
# for A, W in zip(choices, selected_depth_probs[xstagei]):
|
||||
# print('Size = {:}, W = {:}'.format(feature_maps[A].size(), W))
|
||||
possible_tensors = []
|
||||
max_C = max(feature_maps[A].size(1) for A in choices)
|
||||
for tempi, A in enumerate(choices):
|
||||
xtensor = ChannelWiseInter(feature_maps[A], max_C)
|
||||
possible_tensors.append(xtensor)
|
||||
weighted_sum = sum(
|
||||
xtensor * W
|
||||
for xtensor, W in zip(
|
||||
possible_tensors, selected_depth_probs[xstagei]
|
||||
)
|
||||
)
|
||||
x = weighted_sum
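# depth aggregation: the feature maps at every candidate endpoint of this stage are interpolated to the
# widest channel count and combined as a weighted sum with the softmax depth probabilities (weights sum to one).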
|
||||
|
||||
if i in self.depth_at_i:
|
||||
xstagei, xatti = self.depth_at_i[i]
|
||||
x_expected_flop = flop_depth_probs[xstagei, xatti] * expected_flop
|
||||
else:
|
||||
x_expected_flop = expected_flop
|
||||
flops.append(x_expected_flop)
|
||||
flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6))
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = linear_forward(features, self.classifier)
|
||||
return logits, torch.stack([sum(flops)])
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
if self.InShape is None:
|
||||
self.InShape = (inputs.size(-2), inputs.size(-1))
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
@@ -0,0 +1,466 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import math, torch
|
||||
import torch.nn as nn
|
||||
from ..initialization import initialize_resnet
|
||||
from ..SharedUtils import additive_func
|
||||
from .SoftSelect import select2withP, ChannelWiseInter
|
||||
from .SoftSelect import linear_forward
|
||||
from .SoftSelect import get_width_choices as get_choices
|
||||
|
||||
|
||||
def conv_forward(inputs, conv, choices):
|
||||
iC = conv.in_channels
|
||||
fill_size = list(inputs.size())
|
||||
fill_size[1] = iC - fill_size[1]
|
||||
filled = torch.zeros(fill_size, device=inputs.device)
|
||||
xinputs = torch.cat((inputs, filled), dim=1)
|
||||
outputs = conv(xinputs)
|
||||
selecteds = [outputs[:, :oC] for oC in choices]
|
||||
return selecteds
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
num_conv = 1
|
||||
|
||||
def __init__(
|
||||
self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu
|
||||
):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
self.InShape = None
|
||||
self.OutShape = None
|
||||
self.choices = get_choices(nOut)
|
||||
self.register_buffer("choices_tensor", torch.Tensor(self.choices))
|
||||
|
||||
if has_avg:
|
||||
self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
|
||||
else:
|
||||
self.avg = None
|
||||
self.conv = nn.Conv2d(
|
||||
nIn,
|
||||
nOut,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=bias,
|
||||
)
|
||||
# if has_bn : self.bn = nn.BatchNorm2d(nOut)
|
||||
# else : self.bn = None
|
||||
self.has_bn = has_bn
|
||||
self.BNs = nn.ModuleList()
|
||||
for i, _out in enumerate(self.choices):
|
||||
self.BNs.append(nn.BatchNorm2d(_out))
|
||||
if has_relu:
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.relu = None
|
||||
self.in_dim = nIn
|
||||
self.out_dim = nOut
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_flops(self, channels, check_range=True, divide=1):
|
||||
iC, oC = channels
|
||||
if check_range:
|
||||
assert (
|
||||
iC <= self.conv.in_channels and oC <= self.conv.out_channels
|
||||
), "{:} vs {:} | {:} vs {:}".format(
|
||||
iC, self.conv.in_channels, oC, self.conv.out_channels
|
||||
)
|
||||
assert (
|
||||
isinstance(self.InShape, tuple) and len(self.InShape) == 2
|
||||
), "invalid in-shape : {:}".format(self.InShape)
|
||||
assert (
|
||||
isinstance(self.OutShape, tuple) and len(self.OutShape) == 2
|
||||
), "invalid out-shape : {:}".format(self.OutShape)
|
||||
# conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
|
||||
conv_per_position_flops = (
|
||||
self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups
|
||||
)
|
||||
all_positions = self.OutShape[0] * self.OutShape[1]
|
||||
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
|
||||
if self.conv.bias is not None:
|
||||
flops += all_positions / divide
|
||||
return flops
|
||||
|
||||
def get_range(self):
|
||||
return [self.choices]
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, index, prob = tuple_inputs
|
||||
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
|
||||
probability = torch.squeeze(probability)
|
||||
assert len(index) == 2, "invalid length : {:}".format(index)
|
||||
# compute expected flop
|
||||
# coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
|
||||
expected_outC = (self.choices_tensor * probability).sum()
|
||||
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
# convolutional layer
|
||||
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
|
||||
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
|
||||
# merge
|
||||
out_channel = max([x.size(1) for x in out_bns])
|
||||
outA = ChannelWiseInter(out_bns[0], out_channel)
|
||||
outB = ChannelWiseInter(out_bns[1], out_channel)
|
||||
out = outA * prob[0] + outB * prob[1]
|
||||
# out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
|
||||
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
return out, expected_outC, expected_flop
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
if self.avg:
|
||||
out = self.avg(inputs)
|
||||
else:
|
||||
out = inputs
|
||||
conv = self.conv(out)
|
||||
if self.has_bn:
|
||||
out = self.BNs[-1](conv)
|
||||
else:
|
||||
out = conv
|
||||
if self.relu:
|
||||
out = self.relu(out)
|
||||
else:
|
||||
out = out
|
||||
if self.InShape is None:
|
||||
self.InShape = (inputs.size(-2), inputs.size(-1))
|
||||
self.OutShape = (out.size(-2), out.size(-1))
|
||||
return out
|
||||
|
||||
|
||||
class SimBlock(nn.Module):
|
||||
expansion = 1
|
||||
num_conv = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride):
|
||||
super(SimBlock, self).__init__()
|
||||
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
|
||||
self.conv = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
3,
|
||||
stride,
|
||||
1,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=True,
|
||||
)
|
||||
if stride == 2:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=True,
|
||||
has_bn=False,
|
||||
has_relu=False,
|
||||
)
|
||||
elif inplanes != planes:
|
||||
self.downsample = ConvBNReLU(
|
||||
inplanes,
|
||||
planes,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
False,
|
||||
has_avg=False,
|
||||
has_bn=True,
|
||||
has_relu=False,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.out_dim = planes
|
||||
self.search_mode = "basic"
|
||||
|
||||
def get_range(self):
|
||||
return self.conv.get_range()
|
||||
|
||||
def get_flops(self, channels):
|
||||
assert len(channels) == 2, "invalid channels : {:}".format(channels)
|
||||
flop_A = self.conv.get_flops([channels[0], channels[1]])
|
||||
if hasattr(self.downsample, "get_flops"):
|
||||
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
|
||||
else:
|
||||
flop_C = 0
|
||||
if (
|
||||
channels[0] != channels[-1] and self.downsample is None
|
||||
): # this short-cut will be added during the infer-train
|
||||
flop_C = (
|
||||
channels[0]
|
||||
* channels[-1]
|
||||
* self.conv.OutShape[0]
|
||||
* self.conv.OutShape[1]
|
||||
)
|
||||
return flop_A + flop_C
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, tuple_inputs):
|
||||
assert (
|
||||
isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5
|
||||
), "invalid type input : {:}".format(type(tuple_inputs))
|
||||
inputs, expected_inC, probability, indexes, probs = tuple_inputs
|
||||
assert (
|
||||
indexes.size(0) == 1 and probs.size(0) == 1 and probability.size(0) == 1
|
||||
), "invalid size : {:}, {:}, {:}".format(
|
||||
indexes.size(), probs.size(), probability.size()
|
||||
)
|
||||
out, expected_next_inC, expected_flop = self.conv(
|
||||
(inputs, expected_inC, probability[0], indexes[0], probs[0])
|
||||
)
|
||||
if self.downsample is not None:
|
||||
residual, _, expected_flop_c = self.downsample(
|
||||
(inputs, expected_inC, probability[-1], indexes[-1], probs[-1])
|
||||
)
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out)
|
||||
return (
|
||||
nn.functional.relu(out, inplace=True),
|
||||
expected_next_inC,
|
||||
sum([expected_flop, expected_flop_c]),
|
||||
)
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
basicblock = self.conv(inputs)
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(inputs)
|
||||
else:
|
||||
residual = inputs
|
||||
out = additive_func(residual, basicblock)
|
||||
return nn.functional.relu(out, inplace=True)
|
||||
|
||||
|
||||
class SearchWidthSimResNet(nn.Module):
|
||||
def __init__(self, depth, num_classes):
|
||||
super(SearchWidthSimResNet, self).__init__()
|
||||
|
||||
assert (
|
||||
depth - 2
|
||||
) % 3 == 0, "depth should be one of 5, 8, 11, 14, ... instead of {:}".format(
|
||||
depth
|
||||
)
|
||||
layer_blocks = (depth - 2) // 3
|
||||
self.message = (
|
||||
"SearchWidthSimResNet : Depth : {:} , Layers for each block : {:}".format(
|
||||
depth, layer_blocks
|
||||
)
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
self.channels = [16]
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ConvBNReLU(
|
||||
3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True
|
||||
)
|
||||
]
|
||||
)
|
||||
self.InShape = None
|
||||
for stage in range(3):
|
||||
for iL in range(layer_blocks):
|
||||
iC = self.channels[-1]
|
||||
planes = 16 * (2 ** stage)
|
||||
stride = 2 if stage > 0 and iL == 0 else 1
|
||||
module = SimBlock(iC, planes, stride)
|
||||
self.channels.append(module.out_dim)
|
||||
self.layers.append(module)
|
||||
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(
|
||||
stage,
|
||||
iL,
|
||||
layer_blocks,
|
||||
len(self.layers) - 1,
|
||||
iC,
|
||||
module.out_dim,
|
||||
stride,
|
||||
)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(8)
|
||||
self.classifier = nn.Linear(module.out_dim, num_classes)
|
||||
self.InShape = None
|
||||
self.tau = -1
|
||||
self.search_mode = "basic"
|
||||
# assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
|
||||
|
||||
# parameters for width
|
||||
self.Ranges = []
|
||||
self.layer2indexRange = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
start_index = len(self.Ranges)
|
||||
self.Ranges += layer.get_range()
|
||||
self.layer2indexRange.append((start_index, len(self.Ranges)))
|
||||
assert len(self.Ranges) + 1 == depth, "invalid depth check {:} vs {:}".format(
|
||||
len(self.Ranges) + 1, depth
|
||||
)
|
||||
|
||||
self.register_parameter(
|
||||
"width_attentions",
|
||||
nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))),
|
||||
)
|
||||
nn.init.normal_(self.width_attentions, 0, 0.01)
|
||||
self.apply(initialize_resnet)
|
||||
|
||||
def arch_parameters(self):
|
||||
return [self.width_attentions]
|
||||
|
||||
def base_parameters(self):
|
||||
return (
|
||||
list(self.layers.parameters())
|
||||
+ list(self.avgpool.parameters())
|
||||
+ list(self.classifier.parameters())
|
||||
)
|
||||
|
||||
def get_flop(self, mode, config_dict, extra_info):
|
||||
if config_dict is not None:
|
||||
config_dict = config_dict.copy()
|
||||
# weights = [F.softmax(x, dim=0) for x in self.width_attentions]
|
||||
channels = [3]
|
||||
for i, weight in enumerate(self.width_attentions):
|
||||
if mode == "genotype":
|
||||
with torch.no_grad():
|
||||
probe = nn.functional.softmax(weight, dim=0)
|
||||
C = self.Ranges[i][torch.argmax(probe).item()]
|
||||
elif mode == "max":
|
||||
C = self.Ranges[i][-1]
|
||||
elif mode == "fix":
|
||||
C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
|
||||
elif mode == "random":
|
||||
assert isinstance(extra_info, float), "invalid extra_info : {:}".format(
|
||||
extra_info
|
||||
)
|
||||
with torch.no_grad():
|
||||
prob = nn.functional.softmax(weight, dim=0)
|
||||
approximate_C = int(math.sqrt(extra_info) * self.Ranges[i][-1])
|
||||
for j in range(prob.size(0)):
|
||||
prob[j] = 1 / (
|
||||
abs(j - (approximate_C - self.Ranges[i][j])) + 0.2
|
||||
)
|
||||
C = self.Ranges[i][torch.multinomial(prob, 1, False).item()]
|
||||
else:
|
||||
raise ValueError("invalid mode : {:}".format(mode))
|
||||
channels.append(C)
|
||||
flop = 0
|
||||
for i, layer in enumerate(self.layers):
|
||||
s, e = self.layer2indexRange[i]
|
||||
xchl = tuple(channels[s : e + 1])
|
||||
flop += layer.get_flops(xchl)
|
||||
# the last fc layer
|
||||
flop += channels[-1] * self.classifier.out_features
|
||||
if config_dict is None:
|
||||
return flop / 1e6
|
||||
else:
|
||||
config_dict["xchannels"] = channels
|
||||
config_dict["super_type"] = "infer-width"
|
||||
config_dict["estimated_FLOP"] = flop / 1e6
|
||||
return flop / 1e6, config_dict
|
||||
|
||||
def get_arch_info(self):
|
||||
string = "for width, there are {:} attention probabilities.".format(
|
||||
len(self.width_attentions)
|
||||
)
|
||||
discrepancy = []
|
||||
with torch.no_grad():
|
||||
for i, att in enumerate(self.width_attentions):
|
||||
prob = nn.functional.softmax(att, dim=0)
|
||||
prob = prob.cpu()
|
||||
selc = prob.argmax().item()
|
||||
prob = prob.tolist()
|
||||
prob = ["{:.3f}".format(x) for x in prob]
|
||||
xstring = "{:03d}/{:03d}-th : {:}".format(
|
||||
i, len(self.width_attentions), " ".join(prob)
|
||||
)
|
||||
logt = ["{:.3f}".format(x) for x in att.cpu().tolist()]
|
||||
xstring += " || {:52s}".format(" ".join(logt))
|
||||
prob = sorted([float(x) for x in prob])
|
||||
disc = prob[-1] - prob[-2]
|
||||
xstring += " || dis={:.2f} || select={:}/{:}".format(
|
||||
disc, selc, len(prob)
|
||||
)
|
||||
discrepancy.append(disc)
|
||||
string += "\n{:}".format(xstring)
|
||||
return string, discrepancy
|
||||
|
||||
def set_tau(self, tau_max, tau_min, epoch_ratio):
|
||||
assert (
|
||||
epoch_ratio >= 0 and epoch_ratio <= 1
|
||||
), "invalid epoch-ratio : {:}".format(epoch_ratio)
|
||||
tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
|
||||
self.tau = tau
|
||||
|
||||
def get_message(self):
|
||||
return self.message
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.search_mode == "basic":
|
||||
return self.basic_forward(inputs)
|
||||
elif self.search_mode == "search":
|
||||
return self.search_forward(inputs)
|
||||
else:
|
||||
raise ValueError("invalid search_mode = {:}".format(self.search_mode))
|
||||
|
||||
def search_forward(self, inputs):
|
||||
flop_probs = nn.functional.softmax(self.width_attentions, dim=1)
|
||||
selected_widths, selected_probs = select2withP(self.width_attentions, self.tau)
|
||||
with torch.no_grad():
|
||||
selected_widths = selected_widths.cpu()
|
||||
|
||||
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
|
||||
for i, layer in enumerate(self.layers):
|
||||
selected_w_index = selected_widths[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
selected_w_probs = selected_probs[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
layer_prob = flop_probs[
|
||||
last_channel_idx : last_channel_idx + layer.num_conv
|
||||
]
|
||||
x, expected_inC, expected_flop = layer(
|
||||
(x, expected_inC, layer_prob, selected_w_index, selected_w_probs)
|
||||
)
|
||||
last_channel_idx += layer.num_conv
|
||||
flops.append(expected_flop)
|
||||
flops.append(expected_inC * (self.classifier.out_features * 1.0 / 1e6))
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = linear_forward(features, self.classifier)
|
||||
return logits, torch.stack([sum(flops)])
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
if self.InShape is None:
|
||||
self.InShape = (inputs.size(-2), inputs.size(-1))
|
||||
x = inputs
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
features = self.avgpool(x)
|
||||
features = features.view(features.size(0), -1)
|
||||
logits = self.classifier(features)
|
||||
return features, logits
|
||||
128
AutoDL-Projects/xautodl/models/shape_searchs/SoftSelect.py
Normal file
@@ -0,0 +1,128 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import math, torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7):
|
||||
if tau <= 0:
|
||||
new_logits = logits
|
||||
probs = nn.functional.softmax(new_logits, dim=1)
|
||||
else:
|
||||
while True:  # re-sample until the Gumbel noise yields finite, non-NaN probabilities
|
||||
gumbels = -torch.empty_like(logits).exponential_().log()
|
||||
new_logits = (logits.log_softmax(dim=1) + gumbels) / tau
|
||||
probs = nn.functional.softmax(new_logits, dim=1)
|
||||
if (
|
||||
(not torch.isinf(gumbels).any())
|
||||
and (not torch.isinf(probs).any())
|
||||
and (not torch.isnan(probs).any())
|
||||
):
|
||||
break
|
||||
|
||||
if just_prob:
|
||||
return probs
|
||||
|
||||
# with torch.no_grad(): # add eps for unexpected torch error
|
||||
# probs = nn.functional.softmax(new_logits, dim=1)
|
||||
# selected_index = torch.multinomial(probs + eps, 2, False)
|
||||
with torch.no_grad(): # add eps for unexpected torch error
|
||||
probs = probs.cpu()
|
||||
selected_index = torch.multinomial(probs + eps, num, False).to(logits.device)
|
||||
selected_logit = torch.gather(new_logits, 1, selected_index)
|
||||
selected_probs = nn.functional.softmax(selected_logit, dim=1)
|
||||
return selected_index, selected_probs
|
||||
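# Usage sketch (shapes and tau are illustrative, not from this commit): each row of the
# logit matrix is one searchable layer, and select2withP returns the two sampled width
# indices per layer plus their renormalized probabilities.
import torch
example_logits = torch.randn(5, 8)                # 5 layers, 8 width candidates each
sel_index, sel_probs = select2withP(example_logits, tau=4.0)
print(sel_index.shape, sel_probs.shape)           # torch.Size([5, 2]) torch.Size([5, 2])
print(sel_probs.sum(dim=1))                       # each row sums to 1 after re-normalization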
|
||||
|
||||
def ChannelWiseInter(inputs, oC, mode="v2"):
|
||||
if mode == "v1":
|
||||
return ChannelWiseInterV1(inputs, oC)
|
||||
elif mode == "v2":
|
||||
return ChannelWiseInterV2(inputs, oC)
|
||||
else:
|
||||
raise ValueError("invalid mode : {:}".format(mode))
|
||||
|
||||
|
||||
def ChannelWiseInterV1(inputs, oC):
|
||||
assert inputs.dim() == 4, "invalid dimension : {:}".format(inputs.size())
|
||||
|
||||
def start_index(a, b, c):
|
||||
return int(math.floor(float(a * c) / b))
|
||||
|
||||
def end_index(a, b, c):
|
||||
return int(math.ceil(float((a + 1) * c) / b))
|
||||
|
||||
batch, iC, H, W = inputs.size()
|
||||
outputs = torch.zeros((batch, oC, H, W), dtype=inputs.dtype, device=inputs.device)
|
||||
if iC == oC:
|
||||
return inputs
|
||||
for ot in range(oC):
|
||||
istartT, iendT = start_index(ot, oC, iC), end_index(ot, oC, iC)
|
||||
values = inputs[:, istartT:iendT].mean(dim=1)
|
||||
outputs[:, ot, :, :] = values
|
||||
return outputs
|
||||
|
||||
|
||||
def ChannelWiseInterV2(inputs, oC):
|
||||
assert inputs.dim() == 4, "invalid dimension : {:}".format(inputs.size())
|
||||
batch, C, H, W = inputs.size()
|
||||
if C == oC:
|
||||
return inputs
|
||||
else:
|
||||
return nn.functional.adaptive_avg_pool3d(inputs, (oC, H, W))
|
||||
# inputs_5D = inputs.view(batch, 1, C, H, W)
|
||||
# otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'area', None)
|
||||
# otputs = otputs_5D.view(batch, oC, H, W)
|
||||
# otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'trilinear', False)
|
||||
# return otputs
|
||||
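# Usage sketch (tensor sizes are made up): ChannelWiseInter resizes only the channel
# dimension, which is how differently-sized width candidates are aligned during search.
import torch
feats = torch.randn(4, 32, 8, 8)
wider = ChannelWiseInter(feats, 48)               # v2: adaptive_avg_pool3d over the channel axis
narrow = ChannelWiseInter(feats, 16, mode="v1")   # v1: explicit channel-bin averaging
print(wider.shape, narrow.shape)                  # torch.Size([4, 48, 8, 8]) torch.Size([4, 16, 8, 8])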
|
||||
|
||||
def linear_forward(inputs, linear):
|
||||
if linear is None:
|
||||
return inputs
|
||||
iC = inputs.size(1)
|
||||
weight = linear.weight[:, :iC]
|
||||
if linear.bias is None:
|
||||
bias = None
|
||||
else:
|
||||
bias = linear.bias
|
||||
return nn.functional.linear(inputs, weight, bias)
|
||||
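# Usage sketch (sizes are illustrative): during search the feature width can be smaller than
# the classifier expects, so linear_forward uses only the first iC columns of the weight.
import torch
import torch.nn as nn
toy_classifier = nn.Linear(64, 10)
narrow_features = torch.randn(2, 40)              # only 40 of the 64 channels survived pruning
toy_logits = linear_forward(narrow_features, toy_classifier)
print(toy_logits.shape)                           # torch.Size([2, 10])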
|
||||
|
||||
def get_width_choices(nOut):
|
||||
xsrange = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
|
||||
if nOut is None:
|
||||
return len(xsrange)
|
||||
else:
|
||||
Xs = [int(nOut * i) for i in xsrange]
|
||||
# xs = [ int(nOut * i // 10) for i in range(2, 11)]
|
||||
# Xs = [x for i, x in enumerate(xs) if i+1 == len(xs) or xs[i+1] > x+1]
|
||||
Xs = sorted(list(set(Xs)))
|
||||
return tuple(Xs)
|
||||
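# For example (purely illustrative), the width candidates generated for a 64-channel layer:
print(get_width_choices(64))     # (19, 25, 32, 38, 44, 51, 57, 64)
print(get_width_choices(None))   # 8, i.e. the number of candidate ratios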
|
||||
|
||||
def get_depth_choices(nDepth):
|
||||
if nDepth is None:
|
||||
return 3
|
||||
else:
|
||||
assert nDepth >= 3, "nDepth should be at least 3, but got {:}".format(nDepth)
|
||||
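# Note: given the assert above, the nDepth == 1 and nDepth == 2 branches below are
# unreachable; they are kept as-is from the original code.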
if nDepth == 1:
|
||||
return (1, 1, 1)
|
||||
elif nDepth == 2:
|
||||
return (1, 1, 2)
|
||||
elif nDepth >= 3:
|
||||
return (nDepth // 3, nDepth * 2 // 3, nDepth)
|
||||
else:
|
||||
raise ValueError("invalid Depth : {:}".format(nDepth))
|
||||
|
||||
|
||||
def drop_path(x, drop_prob):
|
||||
if drop_prob > 0.0:
|
||||
keep_prob = 1.0 - drop_prob
|
||||
mask = x.new_zeros(x.size(0), 1, 1, 1)
|
||||
mask = mask.bernoulli_(keep_prob)
|
||||
x = x * (mask / keep_prob)
|
||||
# x.div_(keep_prob)
|
||||
# x.mul_(mask)
|
||||
return x
|
||||
9
AutoDL-Projects/xautodl/models/shape_searchs/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
from .SearchCifarResNet_width import SearchWidthCifarResNet
|
||||
from .SearchCifarResNet_depth import SearchDepthCifarResNet
|
||||
from .SearchCifarResNet import SearchShapeCifarResNet
|
||||
from .SearchSimResNet_width import SearchWidthSimResNet
|
||||
from .SearchImagenetResNet import SearchShapeImagenetResNet
|
||||
from .generic_size_tiny_cell_model import GenericNAS301Model
|
||||
209
AutoDL-Projects/xautodl/models/shape_searchs/generic_size_tiny_cell_model.py
Normal file
@@ -0,0 +1,209 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
# Here, we utilized three techniques to search for the number of channels:
|
||||
# - channel-wise interpolation from "Network Pruning via Transformable Architecture Search, NeurIPS 2019"
|
||||
# - masking + Gumbel-Softmax (mask_gumbel) from "FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions, CVPR 2020"
|
||||
# - masking + sampling (mask_rl) from "Can Weight Sharing Outperform Random Architecture Search? An Investigation With TuNAS, CVPR 2020"
|
||||
from typing import List, Text, Any
|
||||
import random, torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from ..cell_infers.cells import InferCell
|
||||
from .SoftSelect import select2withP, ChannelWiseInter
|
||||
|
||||
|
||||
class GenericNAS301Model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
candidate_Cs: List[int],
|
||||
max_num_Cs: int,
|
||||
genotype: Any,
|
||||
num_classes: int,
|
||||
affine: bool,
|
||||
track_running_stats: bool,
|
||||
):
|
||||
super(GenericNAS301Model, self).__init__()
|
||||
self._max_num_Cs = max_num_Cs
|
||||
self._candidate_Cs = candidate_Cs
|
||||
if max_num_Cs % 3 != 2:
|
||||
raise ValueError("invalid number of layers : {:} (max_num_Cs % 3 must equal 2)".format(max_num_Cs))
|
||||
self._num_stage = N = max_num_Cs // 3
|
||||
self._max_C = max(candidate_Cs)
|
||||
|
||||
stem = nn.Sequential(
|
||||
nn.Conv2d(3, self._max_C, kernel_size=3, padding=1, bias=not affine),
|
||||
nn.BatchNorm2d(
|
||||
self._max_C, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
)
|
||||
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
|
||||
c_prev = self._max_C
|
||||
self._cells = nn.ModuleList()
|
||||
self._cells.append(stem)
|
||||
for index, reduction in enumerate(layer_reductions):
|
||||
if reduction:
|
||||
cell = ResNetBasicblock(c_prev, self._max_C, 2, True)
|
||||
else:
|
||||
cell = InferCell(
|
||||
genotype, c_prev, self._max_C, 1, affine, track_running_stats
|
||||
)
|
||||
self._cells.append(cell)
|
||||
c_prev = cell.out_dim
|
||||
self._num_layer = len(self._cells)
|
||||
|
||||
self.lastact = nn.Sequential(
|
||||
nn.BatchNorm2d(
|
||||
c_prev, affine=affine, track_running_stats=track_running_stats
|
||||
),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(c_prev, num_classes)
|
||||
# algorithm related
|
||||
self.register_buffer("_tau", torch.zeros(1))
|
||||
self._algo = None
|
||||
self._warmup_ratio = None
|
||||
|
||||
def set_algo(self, algo: Text):
|
||||
# used for searching
|
||||
assert self._algo is None, "This function can only be called once."
|
||||
assert algo in ["mask_gumbel", "mask_rl", "tas"], "invalid algo : {:}".format(
|
||||
algo
|
||||
)
|
||||
self._algo = algo
|
||||
self._arch_parameters = nn.Parameter(
|
||||
1e-3 * torch.randn(self._max_num_Cs, len(self._candidate_Cs))
|
||||
)
|
||||
# if algo == 'mask_gumbel' or algo == 'mask_rl':
|
||||
self.register_buffer(
|
||||
"_masks", torch.zeros(len(self._candidate_Cs), max(self._candidate_Cs))
|
||||
)
|
||||
for i in range(len(self._candidate_Cs)):
|
||||
self._masks.data[i, : self._candidate_Cs[i]] = 1
|
||||
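# For example, with candidate_Cs = [8, 16, 24] (illustrative values), _masks is a
# 3 x 24 binary matrix: row 0 keeps the first 8 channels, row 1 the first 16, row 2 all 24.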
|
||||
@property
|
||||
def tau(self):
|
||||
return self._tau
|
||||
|
||||
def set_tau(self, tau):
|
||||
self._tau.data[:] = tau
|
||||
|
||||
@property
|
||||
def warmup_ratio(self):
|
||||
return self._warmup_ratio
|
||||
|
||||
def set_warmup_ratio(self, ratio: float):
|
||||
self._warmup_ratio = ratio
|
||||
|
||||
@property
|
||||
def weights(self):
|
||||
xlist = list(self._cells.parameters())
|
||||
xlist += list(self.lastact.parameters())
|
||||
xlist += list(self.global_pooling.parameters())
|
||||
xlist += list(self.classifier.parameters())
|
||||
return xlist
|
||||
|
||||
@property
|
||||
def alphas(self):
|
||||
return [self._arch_parameters]
|
||||
|
||||
def show_alphas(self):
|
||||
with torch.no_grad():
|
||||
return "arch-parameters :\n{:}".format(
|
||||
nn.functional.softmax(self._arch_parameters, dim=-1).cpu()
|
||||
)
|
||||
|
||||
@property
|
||||
def random(self):
|
||||
cs = []
|
||||
for i in range(self._max_num_Cs):
|
||||
index = random.randint(0, len(self._candidate_Cs) - 1)
|
||||
cs.append(str(self._candidate_Cs[index]))
|
||||
return ":".join(cs)
|
||||
|
||||
@property
|
||||
def genotype(self):
|
||||
cs = []
|
||||
for i in range(self._max_num_Cs):
|
||||
with torch.no_grad():
|
||||
index = self._arch_parameters[i].argmax().item()
|
||||
cs.append(str(self._candidate_Cs[index]))
|
||||
return ":".join(cs)
|
||||
|
||||
def get_message(self) -> Text:
|
||||
string = self.extra_repr()
|
||||
for i, cell in enumerate(self._cells):
|
||||
string += "\n {:02d}/{:02d} :: {:}".format(
|
||||
i, len(self._cells), cell.extra_repr()
|
||||
)
|
||||
return string
|
||||
|
||||
def extra_repr(self):
|
||||
return "{name}(candidates={_candidate_Cs}, num={_max_num_Cs}, N={_num_stage}, L={_num_layer})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
feature = inputs
|
||||
|
||||
log_probs = []
|
||||
for i, cell in enumerate(self._cells):
|
||||
feature = cell(feature)
|
||||
# apply different searching algorithms
|
||||
idx = max(0, i - 1)
|
||||
if self._warmup_ratio is not None:
|
||||
if random.random() < self._warmup_ratio:
|
||||
mask = self._masks[-1]
|
||||
else:
|
||||
mask = self._masks[random.randint(0, len(self._masks) - 1)]
|
||||
feature = feature * mask.view(1, -1, 1, 1)
|
||||
elif self._algo == "mask_gumbel":
|
||||
weights = nn.functional.gumbel_softmax(
|
||||
self._arch_parameters[idx : idx + 1], tau=self.tau, dim=-1
|
||||
)
|
||||
mask = torch.matmul(weights, self._masks).view(1, -1, 1, 1)
|
||||
feature = feature * mask
|
||||
elif self._algo == "tas":
|
||||
selected_cs, selected_probs = select2withP(
|
||||
self._arch_parameters[idx : idx + 1], self.tau, num=2
|
||||
)
|
||||
with torch.no_grad():
|
||||
i1, i2 = selected_cs.cpu().view(-1).tolist()
|
||||
c1, c2 = self._candidate_Cs[i1], self._candidate_Cs[i2]
|
||||
out_channel = max(c1, c2)
|
||||
out1 = ChannelWiseInter(feature[:, :c1], out_channel)
|
||||
out2 = ChannelWiseInter(feature[:, :c2], out_channel)
|
||||
out = out1 * selected_probs[0, 0] + out2 * selected_probs[0, 1]
|
||||
if feature.shape[1] == out.shape[1]:
|
||||
feature = out
|
||||
else:
|
||||
miss = torch.zeros(
|
||||
feature.shape[0],
|
||||
feature.shape[1] - out.shape[1],
|
||||
feature.shape[2],
|
||||
feature.shape[3],
|
||||
device=feature.device,
|
||||
)
|
||||
feature = torch.cat((out, miss), dim=1)
|
||||
elif self._algo == "mask_rl":
|
||||
prob = nn.functional.softmax(
|
||||
self._arch_parameters[idx : idx + 1], dim=-1
|
||||
)
|
||||
dist = torch.distributions.Categorical(prob)
|
||||
action = dist.sample()
|
||||
log_probs.append(dist.log_prob(action))
|
||||
mask = self._masks[action.item()].view(1, -1, 1, 1)
|
||||
feature = feature * mask
|
||||
else:
|
||||
raise ValueError("invalid algorithm : {:}".format(self._algo))
|
||||
|
||||
out = self.lastact(feature)
|
||||
out = self.global_pooling(out)
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
return out, logits, log_probs
|
||||
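# Construction sketch (argument values are illustrative; the genotype would normally come
# from a NAS-Bench-201-style search and is not defined in this file):
# model = GenericNAS301Model(candidate_Cs=[8, 16, 24, 32, 40, 48, 56, 64], max_num_Cs=17,
#                            genotype=some_genotype, num_classes=10,
#                            affine=False, track_running_stats=True)
# model.set_algo("mask_gumbel")   # or "mask_rl" / "tas"
# model.set_tau(10.0)
# out, logits, log_probs = model(torch.randn(2, 3, 32, 32))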
76
AutoDL-Projects/xautodl/nas_infer_model/DXYs/CifarNet.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .construct_utils import drop_path
|
||||
from .head_utils import CifarHEAD, AuxiliaryHeadCIFAR
|
||||
from .base_cells import InferCell
|
||||
|
||||
|
||||
class NetworkCIFAR(nn.Module):
|
||||
|
||||
def __init__(self, C, N, stem_multiplier, auxiliary, genotype, num_classes):
|
||||
super(NetworkCIFAR, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
self._stem_multiplier = stem_multiplier
|
||||
|
||||
C_curr = self._stem_multiplier * C
|
||||
self.stem = CifarHEAD(C_curr)
|
||||
|
||||
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
block_indexs = [0 ] * N + [-1 ] + [1 ] * N + [-1 ] + [2 ] * N
|
||||
block2index = {0:[], 1:[], 2:[]}
|
||||
|
||||
C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
|
||||
reduction_prev, spatial, dims = False, 1, []
|
||||
self.auxiliary_index = None
|
||||
self.auxiliary_head = None
|
||||
self.cells = nn.ModuleList()
|
||||
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
|
||||
cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
|
||||
reduction_prev = reduction
|
||||
self.cells.append( cell )
|
||||
C_prev_prev, C_prev = C_prev, cell._multiplier*C_curr
|
||||
if reduction and C_curr == C*4:
|
||||
if auxiliary:
|
||||
self.auxiliary_head = AuxiliaryHeadCIFAR(C_prev, num_classes)
|
||||
self.auxiliary_index = index
|
||||
|
||||
if reduction: spatial *= 2
|
||||
dims.append( (C_prev, spatial) )
|
||||
|
||||
self._Layer = len(self.cells)
|
||||
|
||||
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.drop_path_prob = -1
|
||||
|
||||
def update_drop_path(self, drop_path_prob):
|
||||
self.drop_path_prob = drop_path_prob
|
||||
|
||||
def auxiliary_param(self):
|
||||
if self.auxiliary_head is None: return []
|
||||
else: return list( self.auxiliary_head.parameters() )
|
||||
|
||||
def get_message(self):
|
||||
return self.extra_repr()
|
||||
|
||||
def extra_repr(self):
|
||||
return ('{name}(C={_C}, N={_layerN}, L={_Layer}, stem={_stem_multiplier}, drop-path={drop_path_prob})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def forward(self, inputs):
|
||||
stem_feature, logits_aux = self.stem(inputs), None
|
||||
cell_results = [stem_feature, stem_feature]
|
||||
for i, cell in enumerate(self.cells):
|
||||
cell_feature = cell(cell_results[-2], cell_results[-1], self.drop_path_prob)
|
||||
cell_results.append( cell_feature )
|
||||
|
||||
if self.auxiliary_index is not None and i == self.auxiliary_index and self.training:
|
||||
logits_aux = self.auxiliary_head( cell_results[-1] )
|
||||
out = self.global_pooling( cell_results[-1] )
|
||||
out = out.view(out.size(0), -1)
|
||||
logits = self.classifier(out)
|
||||
|
||||
if logits_aux is None: return out, logits
|
||||
else : return out, [logits, logits_aux]
|
||||
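# Construction sketch (C/N/stem_multiplier are typical DARTS-style values, not taken from
# this commit; GDAS_V1 comes from genotypes.py in the same package):
import torch
from xautodl.nas_infer_model.DXYs import CifarNet, Networks
net = CifarNet(36, 6, 3, False, Networks["GDAS_V1"], 10)  # C, N, stem_multiplier, auxiliary, genotype, num_classes
net.eval()
with torch.no_grad():
    feats, logits = net(torch.randn(2, 3, 32, 32))
print(logits.shape)                                       # torch.Size([2, 10])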
77
AutoDL-Projects/xautodl/nas_infer_model/DXYs/ImageNet.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .construct_utils import drop_path
|
||||
from .base_cells import InferCell
|
||||
from .head_utils import ImageNetHEAD, AuxiliaryHeadImageNet
|
||||
|
||||
|
||||
class NetworkImageNet(nn.Module):
|
||||
|
||||
def __init__(self, C, N, auxiliary, genotype, num_classes):
|
||||
super(NetworkImageNet, self).__init__()
|
||||
self._C = C
|
||||
self._layerN = N
|
||||
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4] * N
|
||||
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
|
||||
self.stem0 = nn.Sequential(
|
||||
nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C),
|
||||
)
|
||||
|
||||
self.stem1 = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(C),
|
||||
)
|
||||
|
||||
C_prev_prev, C_prev, C_curr, reduction_prev = C, C, C, True
|
||||
|
||||
self.cells = nn.ModuleList()
|
||||
self.auxiliary_index = None
|
||||
for i, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
|
||||
cell = InferCell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
|
||||
reduction_prev = reduction
|
||||
self.cells += [cell]
|
||||
C_prev_prev, C_prev = C_prev, cell._multiplier * C_curr
|
||||
if reduction and C_curr == C*4:
|
||||
C_to_auxiliary = C_prev
|
||||
self.auxiliary_index = i
|
||||
|
||||
self._NNN = len(self.cells)
|
||||
if auxiliary:
|
||||
self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
|
||||
else:
|
||||
self.auxiliary_head = None
|
||||
self.global_pooling = nn.AvgPool2d(7)
|
||||
self.classifier = nn.Linear(C_prev, num_classes)
|
||||
self.drop_path_prob = -1
|
||||
|
||||
def update_drop_path(self, drop_path_prob):
|
||||
self.drop_path_prob = drop_path_prob
|
||||
|
||||
def extra_repr(self):
|
||||
return ('{name}(C={_C}, N=[{_layerN}, {_NNN}], aux-index={auxiliary_index}, drop-path={drop_path_prob})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def get_message(self):
|
||||
return self.extra_repr()
|
||||
|
||||
def auxiliary_param(self):
|
||||
if self.auxiliary_head is None: return []
|
||||
else: return list( self.auxiliary_head.parameters() )
|
||||
|
||||
def forward(self, inputs):
|
||||
s0 = self.stem0(inputs)
|
||||
s1 = self.stem1(s0)
|
||||
logits_aux = None
|
||||
for i, cell in enumerate(self.cells):
|
||||
s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
|
||||
if i == self.auxiliary_index and self.auxiliary_head and self.training:
|
||||
logits_aux = self.auxiliary_head(s1)
|
||||
out = self.global_pooling(s1)
|
||||
logits = self.classifier(out.view(out.size(0), -1))
|
||||
|
||||
if logits_aux is None: return out, logits
|
||||
else : return out, [logits, logits_aux]
|
||||
5
AutoDL-Projects/xautodl/nas_infer_model/DXYs/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Performance-Aware Template Network for One-Shot Neural Architecture Search
|
||||
from .CifarNet import NetworkCIFAR as CifarNet
|
||||
from .ImageNet import NetworkImageNet as ImageNet
|
||||
from .genotypes import Networks
|
||||
from .genotypes import build_genotype_from_dict
|
||||
173
AutoDL-Projects/xautodl/nas_infer_model/DXYs/base_cells.py
Normal file
@@ -0,0 +1,173 @@
|
||||
import math
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .construct_utils import drop_path
|
||||
from ..operations import OPS, Identity, FactorizedReduce, ReLUConvBN
|
||||
|
||||
|
||||
class MixedOp(nn.Module):
|
||||
|
||||
def __init__(self, C, stride, PRIMITIVES):
|
||||
super(MixedOp, self).__init__()
|
||||
self._ops = nn.ModuleList()
|
||||
self.name2idx = {}
|
||||
for idx, primitive in enumerate(PRIMITIVES):
|
||||
op = OPS[primitive](C, C, stride, False)
|
||||
self._ops.append(op)
|
||||
assert primitive not in self.name2idx, '{:} is already registered'.format(primitive)
|
||||
self.name2idx[primitive] = idx
|
||||
|
||||
def forward(self, x, weights, op_name):
|
||||
if op_name is None:
|
||||
if weights is None:
|
||||
return [op(x) for op in self._ops]
|
||||
else:
|
||||
return sum(w * op(x) for w, op in zip(weights, self._ops))
|
||||
else:
|
||||
op_index = self.name2idx[op_name]
|
||||
return self._ops[op_index](x)
|
||||
|
||||
|
||||
|
||||
class SearchCell(nn.Module):
|
||||
|
||||
def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, PRIMITIVES, use_residual):
|
||||
super(SearchCell, self).__init__()
|
||||
self.reduction = reduction
|
||||
self.PRIMITIVES = deepcopy(PRIMITIVES)
|
||||
|
||||
if reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2, affine=False)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
|
||||
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
|
||||
self._steps = steps
|
||||
self._multiplier = multiplier
|
||||
self._use_residual = use_residual
|
||||
|
||||
self._ops = nn.ModuleList()
|
||||
for i in range(self._steps):
|
||||
for j in range(2+i):
|
||||
stride = 2 if reduction and j < 2 else 1
|
||||
op = MixedOp(C, stride, self.PRIMITIVES)
|
||||
self._ops.append(op)
|
||||
|
||||
def extra_repr(self):
|
||||
return ('{name}(residual={_use_residual}, steps={_steps}, multiplier={_multiplier})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def forward(self, S0, S1, weights, connect, adjacency, drop_prob, modes):
|
||||
if modes[0] is None:
|
||||
if modes[1] == 'normal':
|
||||
output = self.__forwardBoth(S0, S1, weights, connect, adjacency, drop_prob)
|
||||
elif modes[1] == 'only_W':
|
||||
output = self.__forwardOnlyW(S0, S1, drop_prob)
|
||||
else:
|
||||
test_genotype = modes[0]
|
||||
if self.reduction: operations, concats = test_genotype.reduce, test_genotype.reduce_concat
|
||||
else : operations, concats = test_genotype.normal, test_genotype.normal_concat
|
||||
s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
|
||||
states, offset = [s0, s1], 0
|
||||
assert self._steps == len(operations), '{:} vs. {:}'.format(self._steps, len(operations))
|
||||
for i, (opA, opB) in enumerate(operations):
|
||||
A = self._ops[offset + opA[1]](states[opA[1]], None, opA[0])
|
||||
B = self._ops[offset + opB[1]](states[opB[1]], None, opB[0])
|
||||
state = A + B
|
||||
offset += len(states)
|
||||
states.append(state)
|
||||
output = torch.cat([states[i] for i in concats], dim=1)
|
||||
if self._use_residual and S1.size() == output.size():
|
||||
return S1 + output
|
||||
else: return output
|
||||
|
||||
def __forwardBoth(self, S0, S1, weights, connect, adjacency, drop_prob):
|
||||
s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
|
||||
states, offset = [s0, s1], 0
|
||||
for i in range(self._steps):
|
||||
clist = []
|
||||
for j, h in enumerate(states):
|
||||
x = self._ops[offset+j](h, weights[offset+j], None)
|
||||
if self.training and drop_prob > 0.:
|
||||
x = drop_path(x, math.pow(drop_prob, 1./len(states)))
|
||||
clist.append( x )
|
||||
connection = torch.mm(connect['{:}'.format(i)], adjacency[i]).squeeze(0)
|
||||
state = sum(w * node for w, node in zip(connection, clist))
|
||||
offset += len(states)
|
||||
states.append(state)
|
||||
return torch.cat(states[-self._multiplier:], dim=1)
|
||||
|
||||
def __forwardOnlyW(self, S0, S1, drop_prob):
|
||||
s0, s1 = self.preprocess0(S0), self.preprocess1(S1)
|
||||
states, offset = [s0, s1], 0
|
||||
for i in range(self._steps):
|
||||
clist = []
|
||||
for j, h in enumerate(states):
|
||||
xs = self._ops[offset+j](h, None, None)
|
||||
clist += xs
|
||||
if self.training and drop_prob > 0.:
|
||||
xlist = [drop_path(x, math.pow(drop_prob, 1./len(states))) for x in clist]
|
||||
else: xlist = clist
|
||||
state = sum(xlist) * 2 / len(xlist)
|
||||
offset += len(states)
|
||||
states.append(state)
|
||||
return torch.cat(states[-self._multiplier:], dim=1)
|
||||
|
||||
|
||||
|
||||
class InferCell(nn.Module):
|
||||
|
||||
def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
|
||||
super(InferCell, self).__init__()
|
||||
print(C_prev_prev, C_prev, C)
|
||||
|
||||
if reduction_prev is None:
|
||||
self.preprocess0 = Identity()
|
||||
elif reduction_prev:
|
||||
self.preprocess0 = FactorizedReduce(C_prev_prev, C, 2)
|
||||
else:
|
||||
self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
|
||||
self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
|
||||
|
||||
if reduction: step_ops, concat = genotype.reduce, genotype.reduce_concat
|
||||
else : step_ops, concat = genotype.normal, genotype.normal_concat
|
||||
self._steps = len(step_ops)
|
||||
self._concat = concat
|
||||
self._multiplier = len(concat)
|
||||
self._ops = nn.ModuleList()
|
||||
self._indices = []
|
||||
for operations in step_ops:
|
||||
for name, index in operations:
|
||||
stride = 2 if reduction and index < 2 else 1
|
||||
if reduction_prev is None and index == 0:
|
||||
op = OPS[name](C_prev_prev, C, stride, True)
|
||||
else:
|
||||
op = OPS[name](C , C, stride, True)
|
||||
self._ops.append( op )
|
||||
self._indices.append( index )
|
||||
|
||||
def extra_repr(self):
|
||||
return ('{name}(steps={_steps}, concat={_concat})'.format(name=self.__class__.__name__, **self.__dict__))
|
||||
|
||||
def forward(self, S0, S1, drop_prob):
|
||||
s0 = self.preprocess0(S0)
|
||||
s1 = self.preprocess1(S1)
|
||||
|
||||
states = [s0, s1]
|
||||
for i in range(self._steps):
|
||||
h1 = states[self._indices[2*i]]
|
||||
h2 = states[self._indices[2*i+1]]
|
||||
op1 = self._ops[2*i]
|
||||
op2 = self._ops[2*i+1]
|
||||
h1 = op1(h1)
|
||||
h2 = op2(h2)
|
||||
if self.training and drop_prob > 0.:
|
||||
if not isinstance(op1, Identity):
|
||||
h1 = drop_path(h1, drop_prob)
|
||||
if not isinstance(op2, Identity):
|
||||
h2 = drop_path(h2, drop_prob)
|
||||
|
||||
state = h1 + h2
|
||||
states += [state]
|
||||
output = torch.cat([states[i] for i in self._concat], dim=1)
|
||||
return output
|
||||
60
AutoDL-Projects/xautodl/nas_infer_model/DXYs/construct_utils.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def drop_path(x, drop_prob):
|
||||
if drop_prob > 0.:
|
||||
keep_prob = 1. - drop_prob
|
||||
mask = x.new_zeros(x.size(0), 1, 1, 1)
|
||||
mask = mask.bernoulli_(keep_prob)
|
||||
x = torch.div(x, keep_prob)
|
||||
x.mul_(mask)
|
||||
return x
|
||||
|
||||
|
||||
def return_alphas_str(basemodel):
|
||||
if hasattr(basemodel, 'alphas_normal'):
|
||||
string = 'normal [{:}] : \n-->>{:}'.format(basemodel.alphas_normal.size(), F.softmax(basemodel.alphas_normal, dim=-1) )
|
||||
else: string = ''
|
||||
if hasattr(basemodel, 'alphas_reduce'):
|
||||
string = string + '\nreduce : {:}'.format( F.softmax(basemodel.alphas_reduce, dim=-1) )
|
||||
|
||||
if hasattr(basemodel, 'get_adjacency'):
|
||||
adjacency = basemodel.get_adjacency()
|
||||
for i in range( len(adjacency) ):
|
||||
weight = F.softmax( basemodel.connect_normal[str(i)], dim=-1 )
|
||||
adj = torch.mm(weight, adjacency[i]).view(-1)
|
||||
adj = ['{:3.3f}'.format(x) for x in adj.cpu().tolist()]
|
||||
string = string + '\nnormal--{:}-->{:}'.format(i, ', '.join(adj))
|
||||
for i in range( len(adjacency) ):
|
||||
weight = F.softmax( basemodel.connect_reduce[str(i)], dim=-1 )
|
||||
adj = torch.mm(weight, adjacency[i]).view(-1)
|
||||
adj = ['{:3.3f}'.format(x) for x in adj.cpu().tolist()]
|
||||
string = string + '\nreduce--{:}-->{:}'.format(i, ', '.join(adj))
|
||||
|
||||
if hasattr(basemodel, 'alphas_connect'):
|
||||
weight = F.softmax(basemodel.alphas_connect, dim=-1).cpu()
|
||||
ZERO = ['{:.3f}'.format(x) for x in weight[:,0].tolist()]
|
||||
IDEN = ['{:.3f}'.format(x) for x in weight[:,1].tolist()]
|
||||
string = string + '\nconnect [{:}] : \n ->{:}\n ->{:}'.format( list(basemodel.alphas_connect.size()), ZERO, IDEN )
|
||||
else:
|
||||
string = string + '\nconnect = None'
|
||||
|
||||
if hasattr(basemodel, 'get_gcn_out'):
|
||||
outputs = basemodel.get_gcn_out(True)
|
||||
for i, output in enumerate(outputs):
|
||||
string = string + '\nnormal:[{:}] : {:}'.format(i, F.softmax(output, dim=-1) )
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def remove_duplicate_archs(all_archs):
|
||||
archs = []
|
||||
str_archs = ['{:}'.format(x) for x in all_archs]
|
||||
for i, arch_x in enumerate(str_archs):
|
||||
choose = True
|
||||
for j in range(i):
|
||||
if arch_x == str_archs[j]:
|
||||
choose = False; break
|
||||
if choose: archs.append(all_archs[i])
|
||||
return archs
|
||||
182
AutoDL-Projects/xautodl/nas_infer_model/DXYs/genotypes.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from collections import namedtuple
|
||||
|
||||
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat connectN connects')
|
||||
#Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
|
||||
|
||||
PRIMITIVES_small = [
|
||||
'max_pool_3x3',
|
||||
'avg_pool_3x3',
|
||||
'skip_connect',
|
||||
'sep_conv_3x3',
|
||||
'sep_conv_5x5',
|
||||
'conv_3x1_1x3',
|
||||
]
|
||||
|
||||
PRIMITIVES_large = [
|
||||
'max_pool_3x3',
|
||||
'avg_pool_3x3',
|
||||
'skip_connect',
|
||||
'sep_conv_3x3',
|
||||
'sep_conv_5x5',
|
||||
'dil_conv_3x3',
|
||||
'dil_conv_5x5',
|
||||
'conv_3x1_1x3',
|
||||
]
|
||||
|
||||
PRIMITIVES_huge = [
|
||||
'skip_connect',
|
||||
'nor_conv_1x1',
|
||||
'max_pool_3x3',
|
||||
'avg_pool_3x3',
|
||||
'nor_conv_3x3',
|
||||
'sep_conv_3x3',
|
||||
'dil_conv_3x3',
|
||||
'conv_3x1_1x3',
|
||||
'sep_conv_5x5',
|
||||
'dil_conv_5x5',
|
||||
'sep_conv_7x7',
|
||||
'conv_7x1_1x7',
|
||||
'att_squeeze',
|
||||
]
|
||||
|
||||
PRIMITIVES = {'small': PRIMITIVES_small,
|
||||
'large': PRIMITIVES_large,
|
||||
'huge' : PRIMITIVES_huge}
|
||||
|
||||
NASNet = Genotype(
|
||||
normal = [
|
||||
(('sep_conv_5x5', 1), ('sep_conv_3x3', 0)),
|
||||
(('sep_conv_5x5', 0), ('sep_conv_3x3', 0)),
|
||||
(('avg_pool_3x3', 1), ('skip_connect', 0)),
|
||||
(('avg_pool_3x3', 0), ('avg_pool_3x3', 0)),
|
||||
(('sep_conv_3x3', 1), ('skip_connect', 1)),
|
||||
],
|
||||
normal_concat = [2, 3, 4, 5, 6],
|
||||
reduce = [
|
||||
(('sep_conv_5x5', 1), ('sep_conv_7x7', 0)),
|
||||
(('max_pool_3x3', 1), ('sep_conv_7x7', 0)),
|
||||
(('avg_pool_3x3', 1), ('sep_conv_5x5', 0)),
|
||||
(('skip_connect', 3), ('avg_pool_3x3', 2)),
|
||||
(('sep_conv_3x3', 2), ('max_pool_3x3', 1)),
|
||||
],
|
||||
reduce_concat = [4, 5, 6],
|
||||
connectN=None, connects=None,
|
||||
)
|
||||
|
||||
PNASNet = Genotype(
|
||||
normal = [
|
||||
(('sep_conv_5x5', 0), ('max_pool_3x3', 0)),
|
||||
(('sep_conv_7x7', 1), ('max_pool_3x3', 1)),
|
||||
(('sep_conv_5x5', 1), ('sep_conv_3x3', 1)),
|
||||
(('sep_conv_3x3', 4), ('max_pool_3x3', 1)),
|
||||
(('sep_conv_3x3', 0), ('skip_connect', 1)),
|
||||
],
|
||||
normal_concat = [2, 3, 4, 5, 6],
|
||||
reduce = [
|
||||
(('sep_conv_5x5', 0), ('max_pool_3x3', 0)),
|
||||
(('sep_conv_7x7', 1), ('max_pool_3x3', 1)),
|
||||
(('sep_conv_5x5', 1), ('sep_conv_3x3', 1)),
|
||||
(('sep_conv_3x3', 4), ('max_pool_3x3', 1)),
|
||||
(('sep_conv_3x3', 0), ('skip_connect', 1)),
|
||||
],
|
||||
reduce_concat = [2, 3, 4, 5, 6],
|
||||
connectN=None, connects=None,
|
||||
)
|
||||
|
||||
|
||||
DARTS_V1 = Genotype(
|
||||
normal=[
|
||||
(('sep_conv_3x3', 1), ('sep_conv_3x3', 0)), # step 1
|
||||
(('skip_connect', 0), ('sep_conv_3x3', 1)), # step 2
|
||||
(('skip_connect', 0), ('sep_conv_3x3', 1)), # step 3
|
||||
(('sep_conv_3x3', 0), ('skip_connect', 2)) # step 4
|
||||
],
|
||||
normal_concat=[2, 3, 4, 5],
|
||||
reduce=[
|
||||
(('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1
|
||||
(('skip_connect', 2), ('max_pool_3x3', 0)), # step 2
|
||||
(('max_pool_3x3', 0), ('skip_connect', 2)), # step 3
|
||||
(('skip_connect', 2), ('avg_pool_3x3', 0)) # step 4
|
||||
],
|
||||
reduce_concat=[2, 3, 4, 5],
|
||||
connectN=None, connects=None,
|
||||
)
|
||||
|
||||
# DARTS: Differentiable Architecture Search, ICLR 2019
|
||||
DARTS_V2 = Genotype(
|
||||
normal=[
|
||||
(('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 1
|
||||
(('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 2
|
||||
(('sep_conv_3x3', 1), ('skip_connect', 0)), # step 3
|
||||
(('skip_connect', 0), ('dil_conv_3x3', 2)) # step 4
|
||||
],
|
||||
normal_concat=[2, 3, 4, 5],
|
||||
reduce=[
|
||||
(('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1
|
||||
(('skip_connect', 2), ('max_pool_3x3', 1)), # step 2
|
||||
(('max_pool_3x3', 0), ('skip_connect', 2)), # step 3
|
||||
(('skip_connect', 2), ('max_pool_3x3', 1)) # step 4
|
||||
],
|
||||
reduce_concat=[2, 3, 4, 5],
|
||||
connectN=None, connects=None,
|
||||
)
|
||||
|
||||
|
||||
# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
|
||||
SETN = Genotype(
|
||||
normal=[
|
||||
(('skip_connect', 0), ('sep_conv_5x5', 1)),
|
||||
(('sep_conv_5x5', 0), ('sep_conv_3x3', 1)),
|
||||
(('sep_conv_5x5', 1), ('sep_conv_5x5', 3)),
|
||||
(('max_pool_3x3', 1), ('conv_3x1_1x3', 4))],
|
||||
normal_concat=[2, 3, 4, 5],
|
||||
reduce=[
|
||||
(('sep_conv_3x3', 0), ('sep_conv_5x5', 1)),
|
||||
(('avg_pool_3x3', 0), ('sep_conv_5x5', 1)),
|
||||
(('avg_pool_3x3', 0), ('sep_conv_5x5', 1)),
|
||||
(('avg_pool_3x3', 0), ('skip_connect', 1))],
|
||||
reduce_concat=[2, 3, 4, 5],
|
||||
connectN=None, connects=None
|
||||
)
|
||||
|
||||
|
||||
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019
|
||||
GDAS_V1 = Genotype(
|
||||
normal=[
|
||||
(('skip_connect', 0), ('skip_connect', 1)),
|
||||
(('skip_connect', 0), ('sep_conv_5x5', 2)),
|
||||
(('sep_conv_3x3', 3), ('skip_connect', 0)),
|
||||
(('sep_conv_5x5', 4), ('sep_conv_3x3', 3))],
|
||||
normal_concat=[2, 3, 4, 5],
|
||||
reduce=[
|
||||
(('sep_conv_5x5', 0), ('sep_conv_3x3', 1)),
|
||||
(('sep_conv_5x5', 2), ('sep_conv_5x5', 1)),
|
||||
(('dil_conv_5x5', 2), ('sep_conv_3x3', 1)),
|
||||
(('sep_conv_5x5', 0), ('sep_conv_5x5', 1))],
|
||||
reduce_concat=[2, 3, 4, 5],
|
||||
connectN=None, connects=None
|
||||
)
|
||||
|
||||
|
||||
|
||||
Networks = {'DARTS_V1': DARTS_V1,
|
||||
'DARTS_V2': DARTS_V2,
|
||||
'DARTS' : DARTS_V2,
|
||||
'NASNet' : NASNet,
|
||||
'GDAS_V1' : GDAS_V1,
|
||||
'PNASNet' : PNASNet,
|
||||
'SETN' : SETN,
|
||||
}
|
||||
|
||||
# This function will return a Genotype from a dict.
|
||||
def build_genotype_from_dict(xdict):
|
||||
def remove_value(nodes):
|
||||
return [tuple([(x[0], x[1]) for x in node]) for node in nodes]
|
||||
genotype = Genotype(
|
||||
normal=remove_value(xdict['normal']),
|
||||
normal_concat=xdict['normal_concat'],
|
||||
reduce=remove_value(xdict['reduce']),
|
||||
reduce_concat=xdict['reduce_concat'],
|
||||
connectN=None, connects=None
|
||||
)
|
||||
return genotype
|
||||
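# Round-trip sketch: the dict layout mirrors what build_genotype_from_dict expects; the
# third element of each op tuple (None here, purely illustrative) stands in for whatever
# extra per-op value the function strips.
xdict_example = {
    "normal": [[("sep_conv_3x3", 0, None), ("sep_conv_3x3", 1, None)]],
    "normal_concat": [2],
    "reduce": [[("max_pool_3x3", 0, None), ("max_pool_3x3", 1, None)]],
    "reduce_concat": [2],
}
print(build_genotype_from_dict(xdict_example).normal)
# -> [(('sep_conv_3x3', 0), ('sep_conv_3x3', 1))]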
71
AutoDL-Projects/xautodl/nas_infer_model/DXYs/head_utils.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ImageNetHEAD(nn.Sequential):
|
||||
def __init__(self, C, stride=2):
|
||||
super(ImageNetHEAD, self).__init__()
|
||||
self.add_module(
|
||||
"conv1",
|
||||
nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
)
|
||||
self.add_module("bn1", nn.BatchNorm2d(C // 2))
|
||||
self.add_module("relu1", nn.ReLU(inplace=True))
|
||||
self.add_module(
|
||||
"conv2",
|
||||
nn.Conv2d(C // 2, C, kernel_size=3, stride=stride, padding=1, bias=False),
|
||||
)
|
||||
self.add_module("bn2", nn.BatchNorm2d(C))
|
||||
|
||||
|
||||
class CifarHEAD(nn.Sequential):
|
||||
def __init__(self, C):
|
||||
super(CifarHEAD, self).__init__()
|
||||
self.add_module("conv", nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False))
|
||||
self.add_module("bn", nn.BatchNorm2d(C))
|
||||
|
||||
|
||||
class AuxiliaryHeadCIFAR(nn.Module):
|
||||
def __init__(self, C, num_classes):
|
||||
"""assuming input size 8x8"""
|
||||
super(AuxiliaryHeadCIFAR, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.AvgPool2d(
|
||||
5, stride=3, padding=0, count_include_pad=False
|
||||
), # image size = 2 x 2
|
||||
nn.Conv2d(C, 128, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 768, 2, bias=False),
|
||||
nn.BatchNorm2d(768),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
self.classifier = nn.Linear(768, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.classifier(x.view(x.size(0), -1))
|
||||
return x
|
||||
|
||||
|
||||
class AuxiliaryHeadImageNet(nn.Module):
|
||||
def __init__(self, C, num_classes):
|
||||
"""assuming input size 14x14"""
|
||||
super(AuxiliaryHeadImageNet, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.ReLU(inplace=True),
|
||||
nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
|
||||
nn.Conv2d(C, 128, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(128, 768, 2, bias=False),
|
||||
nn.BatchNorm2d(768),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
self.classifier = nn.Linear(768, num_classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.classifier(x.view(x.size(0), -1))
|
||||
return x
|
||||
51
AutoDL-Projects/xautodl/nas_infer_model/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
# I wrote this package to make AutoDL-Projects compatible with the old GDAS projects.
|
||||
# Ideally, this package will be merged into lib/models/cell_infers in future.
|
||||
# Currently, this package is used to reproduce the results in GDAS (Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019).
|
||||
##################################################
|
||||
|
||||
import os, torch
|
||||
|
||||
|
||||
def obtain_nas_infer_model(config, extra_model_path=None):
|
||||
|
||||
if config.arch == "dxys":
|
||||
from .DXYs import CifarNet, ImageNet, Networks
|
||||
from .DXYs import build_genotype_from_dict
|
||||
|
||||
if config.genotype is None:
|
||||
if extra_model_path is not None and not os.path.isfile(extra_model_path):
|
||||
raise ValueError(
|
||||
"When the genotype in config is None, extra_model_path must be set to a valid path instead of {:}".format(
|
||||
extra_model_path
|
||||
)
|
||||
)
|
||||
xdata = torch.load(extra_model_path)
|
||||
current_epoch = xdata["epoch"]
|
||||
genotype_dict = xdata["genotypes"][current_epoch - 1]
|
||||
genotype = build_genotype_from_dict(genotype_dict)
|
||||
else:
|
||||
genotype = Networks[config.genotype]
|
||||
if config.dataset == "cifar":
|
||||
return CifarNet(
|
||||
config.ichannel,
|
||||
config.layers,
|
||||
config.stem_multi,
|
||||
config.auxiliary,
|
||||
genotype,
|
||||
config.class_num,
|
||||
)
|
||||
elif config.dataset == "imagenet":
|
||||
return ImageNet(
|
||||
config.ichannel,
|
||||
config.layers,
|
||||
config.auxiliary,
|
||||
genotype,
|
||||
config.class_num,
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid dataset : {:}".format(config.dataset))
|
||||
else:
|
||||
raise ValueError("invalid nas arch type : {:}".format(config.arch))
|
||||
183
AutoDL-Projects/xautodl/nas_infer_model/operations.py
Normal file
@@ -0,0 +1,183 @@
|
||||
##############################################################################################
|
||||
# This code is copied and modified from Hanxiao Liu's work (https://github.com/quark0/darts) #
|
||||
##############################################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
OPS = {
|
||||
'none' : lambda C_in, C_out, stride, affine: Zero(stride),
|
||||
'avg_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg'),
|
||||
'max_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'max'),
|
||||
'nor_conv_7x7' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), affine),
|
||||
'nor_conv_3x3' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), affine),
|
||||
'nor_conv_1x1' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), affine),
|
||||
'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine),
|
||||
'sep_conv_3x3' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 3, stride, 1, affine=affine),
|
||||
'sep_conv_5x5' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 5, stride, 2, affine=affine),
|
||||
'sep_conv_7x7' : lambda C_in, C_out, stride, affine: SepConv(C_in, C_out, 7, stride, 3, affine=affine),
|
||||
'dil_conv_3x3' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 3, stride, 2, 2, affine=affine),
|
||||
'dil_conv_5x5' : lambda C_in, C_out, stride, affine: DilConv(C_in, C_out, 5, stride, 4, 2, affine=affine),
|
||||
'conv_7x1_1x7' : lambda C_in, C_out, stride, affine: Conv717(C_in, C_out, stride, affine),
|
||||
'conv_3x1_1x3' : lambda C_in, C_out, stride, affine: Conv313(C_in, C_out, stride, affine)
|
||||
}
|
||||
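# Example lookup (channel counts and stride are arbitrary): every OPS entry shares the
# (C_in, C_out, stride, affine) signature and returns an nn.Module.
import torch
example_op = OPS["sep_conv_3x3"](16, 16, 1, True)
print(example_op(torch.randn(2, 16, 8, 8)).shape)   # torch.Size([2, 16, 8, 8])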
|
||||
|
||||
class POOLING(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, mode):
|
||||
super(POOLING, self).__init__()
|
||||
if C_in == C_out:
|
||||
self.preprocess = None
|
||||
else:
|
||||
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0)
|
||||
if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
|
||||
elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.preprocess is not None:
|
||||
x = self.preprocess(inputs)
|
||||
else: x = inputs
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class Conv313(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, affine):
|
||||
super(Conv313, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in , C_out, (1,3), stride=(1, stride), padding=(0, 1), bias=False),
|
||||
nn.Conv2d(C_out, C_out, (3,1), stride=(stride, 1), padding=(1, 0), bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class Conv717(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, affine):
|
||||
super(Conv717, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in , C_out, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
|
||||
nn.Conv2d(C_out, C_out, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class ReLUConvBN(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
|
||||
super(ReLUConvBN, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class DilConv(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
|
||||
super(DilConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class SepConv(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
|
||||
super(SepConv, self).__init__()
|
||||
self.op = nn.Sequential(
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_in, affine=affine),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride= 1, padding=padding, groups=C_in, bias=False),
|
||||
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
|
||||
nn.BatchNorm2d(C_out, affine=affine),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class Zero(nn.Module):
|
||||
|
||||
def __init__(self, stride):
|
||||
super(Zero, self).__init__()
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
if self.stride == 1:
|
||||
return x.mul(0.)
|
||||
return x[:,:,::self.stride,::self.stride].mul(0.)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'stride={stride}'.format(**self.__dict__)
|
||||
|
||||
|
||||
class FactorizedReduce(nn.Module):
|
||||
|
||||
def __init__(self, C_in, C_out, stride, affine=True):
|
||||
super(FactorizedReduce, self).__init__()
|
||||
self.stride = stride
|
||||
self.C_in = C_in
|
||||
self.C_out = C_out
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
if stride == 2:
|
||||
#assert C_out % 2 == 0, 'C_out : {:}'.format(C_out)
|
||||
C_outs = [C_out // 2, C_out - C_out // 2]
|
||||
self.convs = nn.ModuleList()
|
||||
for i in range(2):
|
||||
self.convs.append( nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False) )
|
||||
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
|
||||
elif stride == 4:
|
||||
assert C_out % 4 == 0, 'C_out : {:}'.format(C_out)
|
||||
self.convs = nn.ModuleList()
|
||||
for i in range(4):
|
||||
self.convs.append( nn.Conv2d(C_in, C_out // 4, 1, stride=stride, padding=0, bias=False) )
|
||||
self.pad = nn.ConstantPad2d((0, 3, 0, 3), 0)
|
||||
else:
|
||||
raise ValueError('Invalid stride : {:}'.format(stride))
|
||||
|
||||
self.bn = nn.BatchNorm2d(C_out, affine=affine)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(x)
|
||||
y = self.pad(x)
|
||||
if self.stride == 2:
|
||||
out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:,1:])], dim=1)
|
||||
else:
|
||||
out = torch.cat([self.convs[0](x), self.convs[1](y[:,:,1:-2,1:-2]),
|
||||
self.convs[2](y[:,:,2:-1,2:-1]), self.convs[3](y[:,:,3:,3:])], dim=1)
|
||||
out = self.bn(out)
|
||||
return out
|
||||
|
||||
def extra_repr(self):
|
||||
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
|
||||
38
AutoDL-Projects/xautodl/procedures/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
######################################################################
|
||||
# This folder is deprecated; its contents are re-organized in "xalgorithms". #
|
||||
######################################################################
|
||||
from .starts import prepare_seed
|
||||
from .starts import prepare_logger
|
||||
from .starts import get_machine_info
|
||||
from .starts import save_checkpoint
|
||||
from .starts import copy_checkpoint
|
||||
from .optimizers import get_optim_scheduler
|
||||
from .funcs_nasbench import evaluate_for_seed as bench_evaluate_for_seed
|
||||
from .funcs_nasbench import pure_evaluate as bench_pure_evaluate
|
||||
from .funcs_nasbench import get_nas_bench_loaders
|
||||
|
||||
|
||||
def get_procedures(procedure):
|
||||
from .basic_main import basic_train, basic_valid
|
||||
from .search_main import search_train, search_valid
|
||||
from .search_main_v2 import search_train_v2
|
||||
from .simple_KD_main import simple_KD_train, simple_KD_valid
|
||||
|
||||
train_funcs = {
|
||||
"basic": basic_train,
|
||||
"search": search_train,
|
||||
"Simple-KD": simple_KD_train,
|
||||
"search-v2": search_train_v2,
|
||||
}
|
||||
valid_funcs = {
|
||||
"basic": basic_valid,
|
||||
"search": search_valid,
|
||||
"Simple-KD": simple_KD_valid,
|
||||
"search-v2": search_valid,
|
||||
}
|
||||
|
||||
train_func = train_funcs[procedure]
|
||||
valid_func = valid_funcs[procedure]
|
||||
return train_func, valid_func
|
||||
99
AutoDL-Projects/xautodl/procedures/advanced_main.py
Normal file
@@ -0,0 +1,99 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
|
||||
#####################################################
|
||||
# To be finished.
|
||||
#
|
||||
import os, sys, time, torch
|
||||
from typing import Optional, Text, Callable
|
||||
|
||||
# modules in AutoDL
|
||||
from xautodl.log_utils import AverageMeter, time_string
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
def get_device(tensors):
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return get_device(tensors[0])
|
||||
elif isinstance(tensors, dict):
|
||||
for key, value in tensors.items():
|
||||
return get_device(value)
|
||||
else:
|
||||
return tensors.device
|
||||
|
||||
|
||||
def basic_train_fn(
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
optimizer,
|
||||
metric,
|
||||
logger,
|
||||
):
|
||||
results = procedure(
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
optimizer,
|
||||
metric,
|
||||
"train",
|
||||
logger,
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def basic_eval_fn(xloader, network, metric, logger):
|
||||
with torch.no_grad():
|
||||
results = procedure(
|
||||
xloader,
|
||||
network,
|
||||
None,
|
||||
None,
|
||||
metric,
|
||||
"valid",
|
||||
logger,
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def procedure(
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
optimizer,
|
||||
metric,
|
||||
mode: Text,
|
||||
logger_fn: Callable = None,
|
||||
):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
if mode.lower() == "train":
|
||||
network.train()
|
||||
elif mode.lower() == "valid":
|
||||
network.eval()
|
||||
else:
|
||||
raise ValueError("The mode is not right : {:}".format(mode))
|
||||
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# calculate prediction and loss
|
||||
|
||||
if mode == "train":
|
||||
optimizer.zero_grad()
|
||||
|
||||
outputs = network(inputs)
|
||||
targets = targets.to(get_device(outputs))
|
||||
|
||||
if mode == "train":
|
||||
loss = criterion(outputs, targets)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# record
|
||||
with torch.no_grad():
|
||||
results = metric(outputs, targets)
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
return metric.get_info()
|
||||
154
AutoDL-Projects/xautodl/procedures/basic_main.py
Normal file
@@ -0,0 +1,154 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
|
||||
# modules in AutoDL
|
||||
from xautodl.log_utils import AverageMeter, time_string
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
def basic_train(
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
optimizer,
|
||||
optim_config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
):
|
||||
loss, acc1, acc5 = procedure(
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
optimizer,
|
||||
"train",
|
||||
optim_config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def basic_valid(
|
||||
xloader, network, criterion, optim_config, extra_info, print_freq, logger
|
||||
):
|
||||
with torch.no_grad():
|
||||
loss, acc1, acc5 = procedure(
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
None,
|
||||
None,
|
||||
"valid",
|
||||
None,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def procedure(
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
optimizer,
|
||||
mode,
|
||||
config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
):
|
||||
data_time, batch_time, losses, top1, top5 = (
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
)
|
||||
if mode == "train":
|
||||
network.train()
|
||||
elif mode == "valid":
|
||||
network.eval()
|
||||
else:
|
||||
raise ValueError("The mode is not right : {:}".format(mode))
|
||||
|
||||
# logger.log('[{:5s}] config :: auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message()))
|
||||
logger.log(
|
||||
"[{:5s}] config :: auxiliary={:}".format(
|
||||
mode, config.auxiliary if hasattr(config, "auxiliary") else -1
|
||||
)
|
||||
)
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if mode == "train":
|
||||
scheduler.update(None, 1.0 * i / len(xloader))
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# calculate prediction and loss
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
|
||||
if mode == "train":
|
||||
optimizer.zero_grad()
|
||||
|
||||
features, logits = network(inputs)
|
||||
if isinstance(logits, list):
|
||||
assert len(logits) == 2, "logits must have {:} items instead of {:}".format(
|
||||
2, len(logits)
|
||||
)
|
||||
logits, logits_aux = logits
|
||||
else:
|
||||
logits, logits_aux = logits, None
|
||||
loss = criterion(logits, targets)
|
||||
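# assumes the network produced auxiliary logits (logits_aux is not None) whenever config.auxiliary > 0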
if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0:
|
||||
loss_aux = criterion(logits_aux, targets)
|
||||
loss += config.auxiliary * loss_aux
|
||||
|
||||
if mode == "train":
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(prec1.item(), inputs.size(0))
|
||||
top5.update(prec5.item(), inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % print_freq == 0 or (i + 1) == len(xloader):
|
||||
Sstr = (
|
||||
" {:5s} ".format(mode.upper())
|
||||
+ time_string()
|
||||
+ " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
|
||||
)
|
||||
if scheduler is not None:
|
||||
Sstr += " {:}".format(scheduler.get_min_info())
|
||||
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
|
||||
batch_time=batch_time, data_time=data_time
|
||||
)
|
||||
Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
|
||||
loss=losses, top1=top1, top5=top5
|
||||
)
|
||||
Istr = "Size={:}".format(list(inputs.size()))
|
||||
logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)
|
||||
|
||||
logger.log(
|
||||
" **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
|
||||
mode=mode.upper(),
|
||||
top1=top1,
|
||||
top5=top5,
|
||||
error1=100 - top1.avg,
|
||||
error5=100 - top5.avg,
|
||||
loss=losses.avg,
|
||||
)
|
||||
)
|
||||
return losses.avg, top1.avg, top5.avg
|
||||
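
A minimal, self-contained sketch of the auxiliary-loss combination that `procedure` above applies (not part of the committed file; batch size, class count and the 0.4 weight are illustrative assumptions):

import torch

criterion = torch.nn.CrossEntropyLoss()
logits = torch.randn(4, 10, requires_grad=True)      # main-head predictions for a batch of 4
logits_aux = torch.randn(4, 10, requires_grad=True)  # auxiliary-head predictions
targets = torch.randint(0, 10, (4,))                 # ground-truth class indices

auxiliary_weight = 0.4  # plays the role of config.auxiliary
loss = criterion(logits, targets)
if auxiliary_weight > 0:
    loss = loss + auxiliary_weight * criterion(logits_aux, targets)
loss.backward()  # in `procedure` this only runs when mode == "train"
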
20
AutoDL-Projects/xautodl/procedures/eval_funcs.py
Normal file
@@ -0,0 +1,20 @@
|
||||
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
#####################################################
import abc


def obtain_accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
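
A quick usage sketch for obtain_accuracy (not part of eval_funcs.py; the toy logits and labels are made up, and it assumes the xautodl package added by this commit is installed):

import torch
from xautodl.procedures.eval_funcs import obtain_accuracy

logits = torch.randn(8, 5)            # 8 samples, 5 classes
targets = torch.randint(0, 5, (8,))   # ground-truth class indices
top1, top5 = obtain_accuracy(logits, targets, topk=(1, 5))
print(top1.item(), top5.item())       # percentages in [0, 100]
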
437
AutoDL-Projects/xautodl/procedures/funcs_nasbench.py
Normal file
@@ -0,0 +1,437 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
|
||||
#####################################################
|
||||
import os, time, copy, torch, pathlib
|
||||
|
||||
from xautodl import datasets
|
||||
from xautodl.config_utils import load_config
|
||||
from xautodl.procedures import prepare_seed, get_optim_scheduler
|
||||
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from xautodl.models import get_cell_based_tiny_net
|
||||
from xautodl.utils import get_model_infos
|
||||
from xautodl.procedures.eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
__all__ = ["evaluate_for_seed", "pure_evaluate", "get_nas_bench_loaders"]
|
||||
|
||||
|
||||
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
|
||||
data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
|
||||
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
latencies, device = [], torch.cuda.current_device()
|
||||
network.eval()
|
||||
with torch.no_grad():
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
targets = targets.cuda(device=device, non_blocking=True)
|
||||
inputs = inputs.cuda(device=device, non_blocking=True)
|
||||
data_time.update(time.time() - end)
|
||||
# forward
|
||||
features, logits = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
batch_time.update(time.time() - end)
|
||||
if batch is None or batch == inputs.size(0):
|
||||
batch = inputs.size(0)
|
||||
latencies.append(batch_time.val - data_time.val)
|
||||
# record loss and accuracy
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(prec1.item(), inputs.size(0))
|
||||
top5.update(prec5.item(), inputs.size(0))
|
||||
end = time.time()
|
||||
if len(latencies) > 2:
|
||||
latencies = latencies[1:]
|
||||
return losses.avg, top1.avg, top5.avg, latencies
|
||||
|
||||
|
||||
def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
|
||||
losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
if mode == "train":
|
||||
network.train()
|
||||
elif mode == "valid":
|
||||
network.eval()
|
||||
else:
|
||||
raise ValueError("The mode is not right : {:}".format(mode))
|
||||
device = torch.cuda.current_device()
|
||||
data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if mode == "train":
|
||||
scheduler.update(None, 1.0 * i / len(xloader))
|
||||
|
||||
targets = targets.cuda(device=device, non_blocking=True)
|
||||
if mode == "train":
|
||||
optimizer.zero_grad()
|
||||
# forward
|
||||
features, logits = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
# backward
|
||||
if mode == "train":
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# record loss and accuracy
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(prec1.item(), inputs.size(0))
|
||||
top5.update(prec5.item(), inputs.size(0))
|
||||
# count time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
return losses.avg, top1.avg, top5.avg, batch_time.sum
|
||||
|
||||
|
||||
def evaluate_for_seed(
|
||||
arch_config, opt_config, train_loader, valid_loaders, seed: int, logger
|
||||
):
|
||||
"""A modular function to train and evaluate a single network, using the given random seed and optimization config with the provided loaders."""
|
||||
prepare_seed(seed) # random seed
|
||||
net = get_cell_based_tiny_net(arch_config)
|
||||
# net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
|
||||
flop, param = get_model_infos(net, opt_config.xshape)
|
||||
logger.log("Network : {:}".format(net.get_message()), False)
|
||||
logger.log(
|
||||
"{:} Seed-------------------------- {:} --------------------------".format(
|
||||
time_string(), seed
|
||||
)
|
||||
)
|
||||
logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param))
|
||||
# train and valid
|
||||
optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config)
|
||||
default_device = torch.cuda.current_device()
|
||||
network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(
|
||||
device=default_device
|
||||
)
|
||||
criterion = criterion.cuda(device=default_device)
|
||||
# start training
|
||||
start_time, epoch_time, total_epoch = (
|
||||
time.time(),
|
||||
AverageMeter(),
|
||||
opt_config.epochs + opt_config.warmup,
|
||||
)
|
||||
(
|
||||
train_losses,
|
||||
train_acc1es,
|
||||
train_acc5es,
|
||||
valid_losses,
|
||||
valid_acc1es,
|
||||
valid_acc5es,
|
||||
) = ({}, {}, {}, {}, {}, {})
|
||||
train_times, valid_times, lrs = {}, {}, {}
|
||||
for epoch in range(total_epoch):
|
||||
scheduler.update(epoch, 0.0)
|
||||
lr = min(scheduler.get_lr())
|
||||
train_loss, train_acc1, train_acc5, train_tm = procedure(
|
||||
train_loader, network, criterion, scheduler, optimizer, "train"
|
||||
)
|
||||
train_losses[epoch] = train_loss
|
||||
train_acc1es[epoch] = train_acc1
|
||||
train_acc5es[epoch] = train_acc5
|
||||
train_times[epoch] = train_tm
|
||||
lrs[epoch] = lr
|
||||
with torch.no_grad():
|
||||
for key, xloader in valid_loaders.items():
|
||||
valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
|
||||
xloader, network, criterion, None, None, "valid"
|
||||
)
|
||||
valid_losses["{:}@{:}".format(key, epoch)] = valid_loss
|
||||
valid_acc1es["{:}@{:}".format(key, epoch)] = valid_acc1
|
||||
valid_acc5es["{:}@{:}".format(key, epoch)] = valid_acc5
|
||||
valid_times["{:}@{:}".format(key, epoch)] = valid_tm
|
||||
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
need_time = "Time Left: {:}".format(
|
||||
convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)
|
||||
)
|
||||
logger.log(
|
||||
"{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}".format(
|
||||
time_string(),
|
||||
need_time,
|
||||
epoch,
|
||||
total_epoch,
|
||||
train_loss,
|
||||
train_acc1,
|
||||
train_acc5,
|
||||
valid_loss,
|
||||
valid_acc1,
|
||||
valid_acc5,
|
||||
lr,
|
||||
)
|
||||
)
|
||||
info_seed = {
|
||||
"flop": flop,
|
||||
"param": param,
|
||||
"arch_config": arch_config._asdict(),
|
||||
"opt_config": opt_config._asdict(),
|
||||
"total_epoch": total_epoch,
|
||||
"train_losses": train_losses,
|
||||
"train_acc1es": train_acc1es,
|
||||
"train_acc5es": train_acc5es,
|
||||
"train_times": train_times,
|
||||
"valid_losses": valid_losses,
|
||||
"valid_acc1es": valid_acc1es,
|
||||
"valid_acc5es": valid_acc5es,
|
||||
"valid_times": valid_times,
|
||||
"learning_rates": lrs,
|
||||
"net_state_dict": net.state_dict(),
|
||||
"net_string": "{:}".format(net),
|
||||
"finish-train": True,
|
||||
}
|
||||
return info_seed
|
||||
|
||||
|
||||
def get_nas_bench_loaders(workers):
|
||||
|
||||
torch.set_num_threads(workers)
|
||||
|
||||
root_dir = (pathlib.Path(__file__).parent / ".." / "..").resolve()
|
||||
torch_dir = pathlib.Path(os.environ["TORCH_HOME"])
|
||||
# cifar
|
||||
cifar_config_path = root_dir / "configs" / "nas-benchmark" / "CIFAR.config"
|
||||
cifar_config = load_config(cifar_config_path, None, None)
|
||||
get_datasets = datasets.get_datasets # a function to return the dataset
|
||||
break_line = "-" * 150
|
||||
print("{:} Create data-loader for all datasets".format(time_string()))
|
||||
print(break_line)
|
||||
TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets(
|
||||
"cifar10", str(torch_dir / "cifar.python"), -1
|
||||
)
|
||||
print(
|
||||
"original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
|
||||
len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num
|
||||
)
|
||||
)
|
||||
cifar10_splits = load_config(
|
||||
root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None
|
||||
)
|
||||
assert cifar10_splits.train[:10] == [
|
||||
0,
|
||||
5,
|
||||
7,
|
||||
11,
|
||||
13,
|
||||
15,
|
||||
16,
|
||||
17,
|
||||
20,
|
||||
24,
|
||||
] and cifar10_splits.valid[:10] == [
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
6,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
12,
|
||||
14,
|
||||
]
|
||||
temp_dataset = copy.deepcopy(TRAIN_CIFAR10)
|
||||
temp_dataset.transform = VALID_CIFAR10.transform
|
||||
# data loader
|
||||
trainval_cifar10_loader = torch.utils.data.DataLoader(
|
||||
TRAIN_CIFAR10,
|
||||
batch_size=cifar_config.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
train_cifar10_loader = torch.utils.data.DataLoader(
|
||||
TRAIN_CIFAR10,
|
||||
batch_size=cifar_config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.train),
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
valid_cifar10_loader = torch.utils.data.DataLoader(
|
||||
temp_dataset,
|
||||
batch_size=cifar_config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar10_splits.valid),
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
test__cifar10_loader = torch.utils.data.DataLoader(
|
||||
VALID_CIFAR10,
|
||||
batch_size=cifar_config.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
print(
|
||||
"CIFAR-10 : trval-loader has {:3d} batch with {:} per batch".format(
|
||||
len(trainval_cifar10_loader), cifar_config.batch_size
|
||||
)
|
||||
)
|
||||
print(
|
||||
"CIFAR-10 : train-loader has {:3d} batch with {:} per batch".format(
|
||||
len(train_cifar10_loader), cifar_config.batch_size
|
||||
)
|
||||
)
|
||||
print(
|
||||
"CIFAR-10 : valid-loader has {:3d} batch with {:} per batch".format(
|
||||
len(valid_cifar10_loader), cifar_config.batch_size
|
||||
)
|
||||
)
|
||||
print(
|
||||
"CIFAR-10 : test--loader has {:3d} batch with {:} per batch".format(
|
||||
len(test__cifar10_loader), cifar_config.batch_size
|
||||
)
|
||||
)
|
||||
print(break_line)
|
||||
# CIFAR-100
|
||||
TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets(
|
||||
"cifar100", str(torch_dir / "cifar.python"), -1
|
||||
)
|
||||
print(
|
||||
"original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
|
||||
len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num
|
||||
)
|
||||
)
|
||||
cifar100_splits = load_config(
|
||||
root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None
|
||||
)
|
||||
assert cifar100_splits.xvalid[:10] == [
|
||||
1,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
8,
|
||||
10,
|
||||
13,
|
||||
14,
|
||||
15,
|
||||
16,
|
||||
] and cifar100_splits.xtest[:10] == [
|
||||
0,
|
||||
2,
|
||||
6,
|
||||
7,
|
||||
9,
|
||||
11,
|
||||
12,
|
||||
17,
|
||||
20,
|
||||
24,
|
||||
]
|
||||
train_cifar100_loader = torch.utils.data.DataLoader(
|
||||
TRAIN_CIFAR100,
|
||||
batch_size=cifar_config.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
valid_cifar100_loader = torch.utils.data.DataLoader(
|
||||
VALID_CIFAR100,
|
||||
batch_size=cifar_config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid),
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
test__cifar100_loader = torch.utils.data.DataLoader(
|
||||
VALID_CIFAR100,
|
||||
batch_size=cifar_config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest),
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
print(
|
||||
"CIFAR-100 : train-loader has {:3d} batch".format(len(train_cifar100_loader))
|
||||
)
|
||||
print(
|
||||
"CIFAR-100 : valid-loader has {:3d} batch".format(len(valid_cifar100_loader))
|
||||
)
|
||||
print(
|
||||
"CIFAR-100 : test--loader has {:3d} batch".format(len(test__cifar100_loader))
|
||||
)
|
||||
print(break_line)
|
||||
|
||||
imagenet16_config_path = root_dir / "configs" / "nas-benchmark" / "ImageNet-16.config"
|
||||
imagenet16_config = load_config(imagenet16_config_path, None, None)
|
||||
TRAIN_ImageNet16_120, VALID_ImageNet16_120, xshape, class_num = get_datasets(
|
||||
"ImageNet16-120", str(torch_dir / "cifar.python" / "ImageNet16"), -1
|
||||
)
|
||||
print(
|
||||
"original TRAIN_ImageNet16_120: {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
|
||||
len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num
|
||||
)
|
||||
)
|
||||
imagenet_splits = load_config(
|
||||
root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt",
|
||||
None,
|
||||
None,
|
||||
)
|
||||
assert imagenet_splits.xvalid[:10] == [
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
12,
|
||||
16,
|
||||
18,
|
||||
] and imagenet_splits.xtest[:10] == [
|
||||
0,
|
||||
4,
|
||||
5,
|
||||
10,
|
||||
11,
|
||||
13,
|
||||
14,
|
||||
15,
|
||||
17,
|
||||
20,
|
||||
]
|
||||
train_imagenet_loader = torch.utils.data.DataLoader(
|
||||
TRAIN_ImageNet16_120,
|
||||
batch_size=imagenet16_config.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
valid_imagenet_loader = torch.utils.data.DataLoader(
|
||||
VALID_ImageNet16_120,
|
||||
batch_size=imagenet16_config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xvalid),
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
test__imagenet_loader = torch.utils.data.DataLoader(
|
||||
VALID_ImageNet16_120,
|
||||
batch_size=imagenet16_config.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_splits.xtest),
|
||||
num_workers=workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
print(
|
||||
"ImageNet-16-120 : train-loader has {:3d} batch with {:} per batch".format(
|
||||
len(train_imagenet_loader), imagenet16_config.batch_size
|
||||
)
|
||||
)
|
||||
print(
|
||||
"ImageNet-16-120 : valid-loader has {:3d} batch with {:} per batch".format(
|
||||
len(valid_imagenet_loader), imagenet16_config.batch_size
|
||||
)
|
||||
)
|
||||
print(
|
||||
"ImageNet-16-120 : test--loader has {:3d} batch with {:} per batch".format(
|
||||
len(test__imagenet_loader), imagenet16_config.batch_size
|
||||
)
|
||||
)
|
||||
|
||||
# 'cifar10', 'cifar100', 'ImageNet16-120'
|
||||
loaders = {
|
||||
"cifar10@trainval": trainval_cifar10_loader,
|
||||
"cifar10@train": train_cifar10_loader,
|
||||
"cifar10@valid": valid_cifar10_loader,
|
||||
"cifar10@test": test__cifar10_loader,
|
||||
"cifar100@train": train_cifar100_loader,
|
||||
"cifar100@valid": valid_cifar100_loader,
|
||||
"cifar100@test": test__cifar100_loader,
|
||||
"ImageNet16-120@train": train_imagenet_loader,
|
||||
"ImageNet16-120@valid": valid_imagenet_loader,
|
||||
"ImageNet16-120@test": test__imagenet_loader,
|
||||
}
|
||||
return loaders
|
||||
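
A hedged sketch of how the loaders above are meant to be consumed (not part of the committed file; it assumes $TORCH_HOME points at the prepared CIFAR/ImageNet16 data and that the configs/nas-benchmark files are present, which the code above already requires):

from xautodl.procedures.funcs_nasbench import get_nas_bench_loaders, pure_evaluate

loaders = get_nas_bench_loaders(workers=4)
print(sorted(loaders.keys()))  # 'ImageNet16-120@test', 'cifar10@train', 'cifar100@valid', ...

# pure_evaluate expects a network that returns (features, logits), e.g. one built by
# get_cell_based_tiny_net(arch_config); `network` below is such a placeholder model.
# loss, acc1, acc5, latencies = pure_evaluate(loaders["cifar10@test"], network.cuda())
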
166
AutoDL-Projects/xautodl/procedures/metric_utils.py
Normal file
@@ -0,0 +1,166 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
|
||||
#####################################################
|
||||
import abc
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0.0
|
||||
self.avg = 0.0
|
||||
self.sum = 0.0
|
||||
self.count = 0.0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(val={val}, avg={avg}, count={count})".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
)
|
||||
|
||||
|
||||
class Metric(abc.ABC):
|
||||
"""The default meta metric class."""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, predictions, targets):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_info(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({inner})".format(
|
||||
name=self.__class__.__name__, inner=self.inner_repr()
|
||||
)
|
||||
|
||||
def inner_repr(self):
|
||||
return ""
|
||||
|
||||
|
||||
class ComposeMetric(Metric):
|
||||
"""The composed metric class."""
|
||||
|
||||
def __init__(self, *metric_list):
|
||||
self.reset()
|
||||
for metric in metric_list:
|
||||
self.append(metric)
|
||||
|
||||
def reset(self):
|
||||
self._metric_list = []
|
||||
|
||||
def append(self, metric):
|
||||
if not isinstance(metric, Metric):
|
||||
raise ValueError(
|
||||
"The input metric is not correct: {:}".format(type(metric))
|
||||
)
|
||||
self._metric_list.append(metric)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._metric_list)
|
||||
|
||||
def __call__(self, predictions, targets):
|
||||
results = list()
|
||||
for metric in self._metric_list:
|
||||
results.append(metric(predictions, targets))
|
||||
return results
|
||||
|
||||
def get_info(self):
|
||||
results = dict()
|
||||
for metric in self._metric_list:
|
||||
for key, value in metric.get_info().items():
|
||||
results[key] = value
|
||||
return results
|
||||
|
||||
def inner_repr(self):
|
||||
xlist = []
|
||||
for metric in self._metric_list:
|
||||
xlist.append(str(metric))
|
||||
return ",".join(xlist)
|
||||
|
||||
|
||||
class MSEMetric(Metric):
|
||||
"""The metric for mse."""
|
||||
|
||||
def __init__(self, ignore_batch):
|
||||
super(MSEMetric, self).__init__()
|
||||
self._ignore_batch = ignore_batch
|
||||
|
||||
def reset(self):
|
||||
self._mse = AverageMeter()
|
||||
|
||||
def __call__(self, predictions, targets):
|
||||
if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
|
||||
loss = torch.nn.functional.mse_loss(predictions.data, targets.data).item()
|
||||
if self._ignore_batch:
|
||||
self._mse.update(loss, 1)
|
||||
else:
|
||||
self._mse.update(loss, predictions.shape[0])
|
||||
return loss
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_info(self):
|
||||
return {"mse": self._mse.avg, "score": self._mse.avg}
|
||||
|
||||
|
||||
class Top1AccMetric(Metric):
|
||||
"""The metric for the top-1 accuracy."""
|
||||
|
||||
def __init__(self, ignore_batch):
|
||||
super(Top1AccMetric, self).__init__()
|
||||
self._ignore_batch = ignore_batch
|
||||
|
||||
def reset(self):
|
||||
self._accuracy = AverageMeter()
|
||||
|
||||
def __call__(self, predictions, targets):
|
||||
if isinstance(predictions, torch.Tensor) and isinstance(targets, torch.Tensor):
|
||||
max_prob_indexes = torch.argmax(predictions, dim=-1)
|
||||
corrects = torch.eq(max_prob_indexes, targets)
|
||||
accuracy = corrects.float().mean().float()
|
||||
if self._ignore_batch:
|
||||
self._accuracy.update(accuracy, 1)
|
||||
else: # [TODO] for 3-d tensor
|
||||
self._accuracy.update(accuracy, predictions.shape[0])
|
||||
return accuracy
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_info(self):
|
||||
return {"accuracy": self._accuracy.avg, "score": self._accuracy.avg * 100}
|
||||
|
||||
|
||||
class SaveMetric(Metric):
|
||||
"""The metric for mse."""
|
||||
|
||||
def reset(self):
|
||||
self._predicts = []
|
||||
|
||||
def __call__(self, predictions, targets=None):
|
||||
if isinstance(predictions, torch.Tensor):
|
||||
predicts = predictions.cpu().numpy()
|
||||
self._predicts.append(predicts)
|
||||
return predicts
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_info(self):
|
||||
all_predicts = np.concatenate(self._predicts)
|
||||
return {"predictions": all_predicts}
|
||||
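
A small sketch of the metric helpers above (not part of the committed file; values and tensor shapes are toy data):

import torch
from xautodl.procedures.metric_utils import AverageMeter, Top1AccMetric

meter = AverageMeter()
for value, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
    meter.update(value, n=batch_size)
print(meter)  # AverageMeter(val=0.5, avg=..., count=80.0)

metric = Top1AccMetric(ignore_batch=False)
predictions = torch.randn(8, 10)      # toy logits
targets = torch.randint(0, 10, (8,))  # toy labels
metric(predictions, targets)          # updates the running accuracy
print(metric.get_info())              # {'accuracy': ..., 'score': ...}
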
263
AutoDL-Projects/xautodl/procedures/optimizers.py
Normal file
@@ -0,0 +1,263 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import math, torch
|
||||
import torch.nn as nn
|
||||
from bisect import bisect_right
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
class _LRScheduler(object):
|
||||
def __init__(self, optimizer, warmup_epochs, epochs):
|
||||
if not isinstance(optimizer, Optimizer):
|
||||
raise TypeError("{:} is not an Optimizer".format(type(optimizer).__name__))
|
||||
self.optimizer = optimizer
|
||||
for group in optimizer.param_groups:
|
||||
group.setdefault("initial_lr", group["lr"])
|
||||
self.base_lrs = list(
|
||||
map(lambda group: group["initial_lr"], optimizer.param_groups)
|
||||
)
|
||||
self.max_epochs = epochs
|
||||
self.warmup_epochs = warmup_epochs
|
||||
self.current_epoch = 0
|
||||
self.current_iter = 0
|
||||
|
||||
def extra_repr(self):
|
||||
return ""
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(warmup={warmup_epochs}, max-epoch={max_epochs}, current::epoch={current_epoch}, iter={current_iter:.2f}".format(
|
||||
name=self.__class__.__name__, **self.__dict__
|
||||
) + ", {:})".format(
|
||||
self.extra_repr()
|
||||
)
|
||||
|
||||
def state_dict(self):
|
||||
return {
|
||||
key: value for key, value in self.__dict__.items() if key != "optimizer"
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
def get_lr(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_min_info(self):
|
||||
lrs = self.get_lr()
|
||||
return "#LR=[{:.6f}~{:.6f}] epoch={:03d}, iter={:4.2f}#".format(
|
||||
min(lrs), max(lrs), self.current_epoch, self.current_iter
|
||||
)
|
||||
|
||||
def get_min_lr(self):
|
||||
return min(self.get_lr())
|
||||
|
||||
def update(self, cur_epoch, cur_iter):
|
||||
if cur_epoch is not None:
|
||||
assert (
|
||||
isinstance(cur_epoch, int) and cur_epoch >= 0
|
||||
), "invalid cur-epoch : {:}".format(cur_epoch)
|
||||
self.current_epoch = cur_epoch
|
||||
if cur_iter is not None:
|
||||
assert (
|
||||
isinstance(cur_iter, float) and cur_iter >= 0
|
||||
), "invalid cur-iter : {:}".format(cur_iter)
|
||||
self.current_iter = cur_iter
|
||||
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
|
||||
param_group["lr"] = lr
|
||||
|
||||
|
||||
class CosineAnnealingLR(_LRScheduler):
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, T_max, eta_min):
|
||||
self.T_max = T_max
|
||||
self.eta_min = eta_min
|
||||
super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return "type={:}, T-max={:}, eta-min={:}".format(
|
||||
"cosine", self.T_max, self.eta_min
|
||||
)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if (
|
||||
self.current_epoch >= self.warmup_epochs
|
||||
and self.current_epoch < self.max_epochs
|
||||
):
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
# if last_epoch < self.T_max:
|
||||
# if last_epoch < self.max_epochs:
|
||||
lr = (
|
||||
self.eta_min
|
||||
+ (base_lr - self.eta_min)
|
||||
* (1 + math.cos(math.pi * last_epoch / self.T_max))
|
||||
/ 2
|
||||
)
|
||||
# else:
|
||||
# lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
|
||||
elif self.current_epoch >= self.max_epochs:
|
||||
lr = self.eta_min
|
||||
else:
|
||||
lr = (
|
||||
self.current_epoch / self.warmup_epochs
|
||||
+ self.current_iter / self.warmup_epochs
|
||||
) * base_lr
|
||||
lrs.append(lr)
|
||||
return lrs
|
||||
|
||||
|
||||
class MultiStepLR(_LRScheduler):
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
|
||||
assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(
|
||||
len(milestones), len(gammas)
|
||||
)
|
||||
self.milestones = milestones
|
||||
self.gammas = gammas
|
||||
super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return "type={:}, milestones={:}, gammas={:}, base-lrs={:}".format(
|
||||
"multistep", self.milestones, self.gammas, self.base_lrs
|
||||
)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
idx = bisect_right(self.milestones, last_epoch)
|
||||
lr = base_lr
|
||||
for x in self.gammas[:idx]:
|
||||
lr *= x
|
||||
else:
|
||||
lr = (
|
||||
self.current_epoch / self.warmup_epochs
|
||||
+ self.current_iter / self.warmup_epochs
|
||||
) * base_lr
|
||||
lrs.append(lr)
|
||||
return lrs
|
||||
|
||||
|
||||
class ExponentialLR(_LRScheduler):
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, gamma):
|
||||
self.gamma = gamma
|
||||
super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return "type={:}, gamma={:}, base-lrs={:}".format(
|
||||
"exponential", self.gamma, self.base_lrs
|
||||
)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
|
||||
lr = base_lr * (self.gamma**last_epoch)
|
||||
else:
|
||||
lr = (
|
||||
self.current_epoch / self.warmup_epochs
|
||||
+ self.current_iter / self.warmup_epochs
|
||||
) * base_lr
|
||||
lrs.append(lr)
|
||||
return lrs
|
||||
|
||||
|
||||
class LinearLR(_LRScheduler):
|
||||
def __init__(self, optimizer, warmup_epochs, epochs, max_LR, min_LR):
|
||||
self.max_LR = max_LR
|
||||
self.min_LR = min_LR
|
||||
super(LinearLR, self).__init__(optimizer, warmup_epochs, epochs)
|
||||
|
||||
def extra_repr(self):
|
||||
return "type={:}, max_LR={:}, min_LR={:}, base-lrs={:}".format(
|
||||
"LinearLR", self.max_LR, self.min_LR, self.base_lrs
|
||||
)
|
||||
|
||||
def get_lr(self):
|
||||
lrs = []
|
||||
for base_lr in self.base_lrs:
|
||||
if self.current_epoch >= self.warmup_epochs:
|
||||
last_epoch = self.current_epoch - self.warmup_epochs
|
||||
assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
|
||||
ratio = (
|
||||
(self.max_LR - self.min_LR)
|
||||
* last_epoch
|
||||
/ self.max_epochs
|
||||
/ self.max_LR
|
||||
)
|
||||
lr = base_lr * (1 - ratio)
|
||||
else:
|
||||
lr = (
|
||||
self.current_epoch / self.warmup_epochs
|
||||
+ self.current_iter / self.warmup_epochs
|
||||
) * base_lr
|
||||
lrs.append(lr)
|
||||
return lrs
|
||||
|
||||
|
||||
class CrossEntropyLabelSmooth(nn.Module):
|
||||
def __init__(self, num_classes, epsilon):
|
||||
super(CrossEntropyLabelSmooth, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.epsilon = epsilon
|
||||
self.logsoftmax = nn.LogSoftmax(dim=1)
|
||||
|
||||
def forward(self, inputs, targets):
|
||||
log_probs = self.logsoftmax(inputs)
|
||||
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
|
||||
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
|
||||
loss = (-targets * log_probs).mean(0).sum()
|
||||
return loss
|
||||
|
||||
|
||||
def get_optim_scheduler(parameters, config):
|
||||
assert (
|
||||
hasattr(config, "optim")
|
||||
and hasattr(config, "scheduler")
|
||||
and hasattr(config, "criterion")
|
||||
), "config must have optim / scheduler / criterion keys instead of {:}".format(
|
||||
config
|
||||
)
|
||||
if config.optim == "SGD":
|
||||
optim = torch.optim.SGD(
|
||||
parameters,
|
||||
config.LR,
|
||||
momentum=config.momentum,
|
||||
weight_decay=config.decay,
|
||||
nesterov=config.nesterov,
|
||||
)
|
||||
elif config.optim == "RMSprop":
|
||||
optim = torch.optim.RMSprop(
|
||||
parameters, config.LR, momentum=config.momentum, weight_decay=config.decay
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid optim : {:}".format(config.optim))
|
||||
|
||||
if config.scheduler == "cos":
|
||||
T_max = getattr(config, "T_max", config.epochs)
|
||||
scheduler = CosineAnnealingLR(
|
||||
optim, config.warmup, config.epochs, T_max, config.eta_min
|
||||
)
|
||||
elif config.scheduler == "multistep":
|
||||
scheduler = MultiStepLR(
|
||||
optim, config.warmup, config.epochs, config.milestones, config.gammas
|
||||
)
|
||||
elif config.scheduler == "exponential":
|
||||
scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
|
||||
elif config.scheduler == "linear":
|
||||
scheduler = LinearLR(
|
||||
optim, config.warmup, config.epochs, config.LR, config.LR_min
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid scheduler : {:}".format(config.scheduler))
|
||||
|
||||
if config.criterion == "Softmax":
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
elif config.criterion == "SmoothSoftmax":
|
||||
criterion = CrossEntropyLabelSmooth(config.class_num, config.label_smooth)
|
||||
else:
|
||||
raise ValueError("invalid criterion : {:}".format(config.criterion))
|
||||
return optim, scheduler, criterion
|
||||
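
A minimal sketch of driving get_optim_scheduler above (not part of the committed file; the config is an ad-hoc namespace carrying just the attributes the function checks, and the concrete values are illustrative assumptions):

import torch
from types import SimpleNamespace
from xautodl.procedures.optimizers import get_optim_scheduler

model = torch.nn.Linear(16, 10)
config = SimpleNamespace(
    optim="SGD", LR=0.1, momentum=0.9, decay=5e-4, nesterov=True,
    scheduler="cos", warmup=5, epochs=100, eta_min=0.0,
    criterion="Softmax",
)
optimizer, scheduler, criterion = get_optim_scheduler(model.parameters(), config)
print(scheduler)               # CosineAnnealingLR(warmup=5, max-epoch=100, ...)

scheduler.update(0, 0.0)       # start of the linear warmup
print(scheduler.get_min_lr())  # 0.0
scheduler.update(50, 0.0)      # mid-way through the cosine phase
print(scheduler.get_min_lr())
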
150
AutoDL-Projects/xautodl/procedures/q_exps.py
Normal file
@@ -0,0 +1,150 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
|
||||
#####################################################
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import pprint
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
|
||||
import qlib
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.utils import flatten_dict
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
|
||||
def set_log_basic_config(filename=None, format=None, level=None):
|
||||
"""
|
||||
Set the basic configuration for the logging system.
|
||||
See details at https://docs.python.org/3/library/logging.html#logging.basicConfig
|
||||
:param filename: str or None
|
||||
The path to save the logs.
|
||||
:param format: the logging format
|
||||
:param level: int
|
||||
:return: Logger
|
||||
Logger object.
|
||||
"""
|
||||
from qlib.config import C
|
||||
|
||||
if level is None:
|
||||
level = C.logging_level
|
||||
|
||||
if format is None:
|
||||
format = C.logging_config["formatters"]["logger_format"]["format"]
|
||||
|
||||
# Remove all handlers associated with the root logger object.
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
logging.basicConfig(filename=filename, format=format, level=level)
|
||||
|
||||
|
||||
def update_gpu(config, gpu):
|
||||
config = deepcopy(config)
|
||||
if "task" in config and "model" in config["task"]:
|
||||
if "GPU" in config["task"]["model"]:
|
||||
config["task"]["model"]["GPU"] = gpu
|
||||
elif (
|
||||
"kwargs" in config["task"]["model"]
|
||||
and "GPU" in config["task"]["model"]["kwargs"]
|
||||
):
|
||||
config["task"]["model"]["kwargs"]["GPU"] = gpu
|
||||
elif "model" in config:
|
||||
if "GPU" in config["model"]:
|
||||
config["model"]["GPU"] = gpu
|
||||
elif "kwargs" in config["model"] and "GPU" in config["model"]["kwargs"]:
|
||||
config["model"]["kwargs"]["GPU"] = gpu
|
||||
elif "kwargs" in config and "GPU" in config["kwargs"]:
|
||||
config["kwargs"]["GPU"] = gpu
|
||||
elif "GPU" in config:
|
||||
config["GPU"] = gpu
|
||||
return config
|
||||
|
||||
|
||||
def update_market(config, market):
|
||||
config = deepcopy(config.copy())
|
||||
config["market"] = market
|
||||
config["data_handler_config"]["instruments"] = market
|
||||
return config
|
||||
|
||||
|
||||
def run_exp(
|
||||
task_config,
|
||||
dataset,
|
||||
experiment_name,
|
||||
recorder_name,
|
||||
uri,
|
||||
model_obj_name="model.pkl",
|
||||
):
|
||||
|
||||
model = init_instance_by_config(task_config["model"])
|
||||
model_fit_kwargs = dict(dataset=dataset)
|
||||
|
||||
# Let's start the experiment.
|
||||
with R.start(
|
||||
experiment_name=experiment_name,
|
||||
recorder_name=recorder_name,
|
||||
uri=uri,
|
||||
resume=True,
|
||||
):
|
||||
# Setup log
|
||||
recorder_root_dir = R.get_recorder().get_local_dir()
|
||||
log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name))
|
||||
|
||||
set_log_basic_config(log_file)
|
||||
logger = get_module_logger("q.run_exp")
|
||||
logger.info("task_config::\n{:}".format(pprint.pformat(task_config, indent=2)))
|
||||
logger.info("[{:}] - [{:}]: {:}".format(experiment_name, recorder_name, uri))
|
||||
logger.info("dataset={:}".format(dataset))
|
||||
|
||||
# Train model
|
||||
try:
|
||||
if hasattr(model, "to"): # Recoverable model
|
||||
ori_device = model.device
|
||||
model = R.load_object(model_obj_name)
|
||||
model.to(ori_device)
|
||||
else:
|
||||
model = R.load_object(model_obj_name)
|
||||
logger.info("[Find existing object from {:}]".format(model_obj_name))
|
||||
except OSError:
|
||||
R.log_params(**flatten_dict(update_gpu(task_config, None)))
|
||||
if "save_path" in inspect.getfullargspec(model.fit).args:
|
||||
model_fit_kwargs["save_path"] = os.path.join(
|
||||
recorder_root_dir, "model.ckp"
|
||||
)
|
||||
elif "save_dir" in inspect.getfullargspec(model.fit).args:
|
||||
model_fit_kwargs["save_dir"] = os.path.join(
|
||||
recorder_root_dir, "model-ckps"
|
||||
)
|
||||
model.fit(**model_fit_kwargs)
|
||||
# remove model to CPU for saving
|
||||
if hasattr(model, "to"):
|
||||
old_device = model.device
|
||||
model.to("cpu")
|
||||
R.save_objects(**{model_obj_name: model})
|
||||
model.to(old_device)
|
||||
else:
|
||||
R.save_objects(**{model_obj_name: model})
|
||||
except Exception as e:
|
||||
raise ValueError("Something wrong: {:}".format(e))
|
||||
# Get the recorder
|
||||
recorder = R.get_recorder()
|
||||
|
||||
# Generate records: prediction, backtest, and analysis
|
||||
for record in task_config["record"]:
|
||||
record = deepcopy(record)
|
||||
if record["class"] == "MultiSegRecord":
|
||||
record["kwargs"] = dict(model=model, dataset=dataset, recorder=recorder)
|
||||
sr = init_instance_by_config(record)
|
||||
sr.generate(**record["generate_kwargs"])
|
||||
elif record["class"] == "SignalRecord":
|
||||
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
|
||||
record["kwargs"].update(srconf)
|
||||
sr = init_instance_by_config(record)
|
||||
sr.generate()
|
||||
else:
|
||||
rconf = {"recorder": recorder}
|
||||
record["kwargs"].update(rconf)
|
||||
ar = init_instance_by_config(record)
|
||||
ar.generate()
|
||||
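
A small sketch of the config-patching helpers above (not part of the committed file; the nested dict mirrors the qlib task layout they expect, with made-up values, and importing q_exps assumes qlib is installed):

from xautodl.procedures.q_exps import update_gpu, update_market

task_config = {
    "market": "csi300",
    "data_handler_config": {"instruments": "csi300"},
    "task": {"model": {"class": "SomeModel", "GPU": 0}},
}
on_gpu1 = update_gpu(task_config, 1)
print(on_gpu1["task"]["model"]["GPU"])                  # 1; the input dict is deep-copied, not mutated

on_csi500 = update_market(task_config, "csi500")
print(on_csi500["data_handler_config"]["instruments"])  # csi500
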
199
AutoDL-Projects/xautodl/procedures/search_main.py
Normal file
@@ -0,0 +1,199 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
|
||||
from xautodl.log_utils import AverageMeter, time_string
|
||||
from xautodl.models import change_key
|
||||
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
|
||||
expected_flop = torch.mean(expected_flop)
|
||||
|
||||
if flop_cur < flop_need - flop_tolerant: # Too Small FLOP
|
||||
loss = -torch.log(expected_flop)
|
||||
# elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
|
||||
elif flop_cur > flop_need: # Too Large FLOP
|
||||
loss = torch.log(expected_flop)
|
||||
else: # Required FLOP
|
||||
loss = None
|
||||
if loss is None:
|
||||
return 0, 0
|
||||
else:
|
||||
return loss, loss.item()
|
||||
|
||||
|
||||
def search_train(
|
||||
search_loader,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
base_optimizer,
|
||||
arch_optimizer,
|
||||
optim_config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
base_losses, arch_losses, top1, top5 = (
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
)
|
||||
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
|
||||
epoch_str, flop_need, flop_weight, flop_tolerant = (
|
||||
extra_info["epoch-str"],
|
||||
extra_info["FLOP-exp"],
|
||||
extra_info["FLOP-weight"],
|
||||
extra_info["FLOP-tolerant"],
|
||||
)
|
||||
|
||||
network.train()
|
||||
logger.log(
|
||||
"[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(
|
||||
epoch_str, flop_need, flop_weight
|
||||
)
|
||||
)
|
||||
end = time.time()
|
||||
network.apply(change_key("search_mode", "search"))
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
|
||||
search_loader
|
||||
):
|
||||
scheduler.update(None, 1.0 * step / len(search_loader))
|
||||
# calculate prediction and loss
|
||||
base_targets = base_targets.cuda(non_blocking=True)
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# update the weights
|
||||
base_optimizer.zero_grad()
|
||||
logits, expected_flop = network(base_inputs)
|
||||
# network.apply( change_key('search_mode', 'basic') )
|
||||
# features, logits = network(base_inputs)
|
||||
base_loss = criterion(logits, base_targets)
|
||||
base_loss.backward()
|
||||
base_optimizer.step()
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
|
||||
base_losses.update(base_loss.item(), base_inputs.size(0))
|
||||
top1.update(prec1.item(), base_inputs.size(0))
|
||||
top5.update(prec5.item(), base_inputs.size(0))
|
||||
|
||||
# update the architecture
|
||||
arch_optimizer.zero_grad()
|
||||
logits, expected_flop = network(arch_inputs)
|
||||
flop_cur = network.module.get_flop("genotype", None, None)
|
||||
flop_loss, flop_loss_scale = get_flop_loss(
|
||||
expected_flop, flop_cur, flop_need, flop_tolerant
|
||||
)
|
||||
acls_loss = criterion(logits, arch_targets)
|
||||
arch_loss = acls_loss + flop_loss * flop_weight
|
||||
arch_loss.backward()
|
||||
arch_optimizer.step()
|
||||
|
||||
# record
|
||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||
arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
|
||||
arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
if step % print_freq == 0 or (step + 1) == len(search_loader):
|
||||
Sstr = (
|
||||
"**TRAIN** "
|
||||
+ time_string()
|
||||
+ " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
|
||||
)
|
||||
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
|
||||
batch_time=batch_time, data_time=data_time
|
||||
)
|
||||
Lstr = "Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
|
||||
loss=base_losses, top1=top1, top5=top5
|
||||
)
|
||||
Vstr = "Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})".format(
|
||||
aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses
|
||||
)
|
||||
logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Vstr)
|
||||
# Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
|
||||
# logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
|
||||
# print(network.module.get_arch_info())
|
||||
# print(network.module.width_attentions[0])
|
||||
# print(network.module.width_attentions[1])
|
||||
|
||||
logger.log(
|
||||
" **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}".format(
|
||||
top1=top1,
|
||||
top5=top5,
|
||||
error1=100 - top1.avg,
|
||||
error5=100 - top5.avg,
|
||||
baseloss=base_losses.avg,
|
||||
archloss=arch_losses.avg,
|
||||
)
|
||||
)
|
||||
return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
|
||||
|
||||
|
||||
def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
|
||||
data_time, batch_time, losses, top1, top5 = (
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
)
|
||||
|
||||
network.eval()
|
||||
network.apply(change_key("search_mode", "search"))
|
||||
end = time.time()
|
||||
# logger.log('Starting evaluating {:}'.format(epoch_info))
|
||||
with torch.no_grad():
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# calculate prediction and loss
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
|
||||
logits, expected_flop = network(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(prec1.item(), inputs.size(0))
|
||||
top5.update(prec5.item(), inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % print_freq == 0 or (i + 1) == len(xloader):
|
||||
Sstr = (
|
||||
"**VALID** "
|
||||
+ time_string()
|
||||
+ " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
|
||||
)
|
||||
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
|
||||
batch_time=batch_time, data_time=data_time
|
||||
)
|
||||
Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
|
||||
loss=losses, top1=top1, top5=top5
|
||||
)
|
||||
Istr = "Size={:}".format(list(inputs.size()))
|
||||
logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)
|
||||
|
||||
logger.log(
|
||||
" **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
|
||||
top1=top1,
|
||||
top5=top5,
|
||||
error1=100 - top1.avg,
|
||||
error5=100 - top5.avg,
|
||||
loss=losses.avg,
|
||||
)
|
||||
)
|
||||
|
||||
return losses.avg, top1.avg, top5.avg
|
||||
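
A toy sketch of the piecewise FLOP regulariser above: it rewards a larger expected FLOP when the sampled architecture is too small, penalises it when too large, and vanishes inside the tolerance band (not part of the committed file; all numbers are illustrative):

import torch
from xautodl.procedures.search_main import get_flop_loss

expected_flop = torch.tensor([120.0, 130.0])  # differentiable FLOP estimate (MB)
flop_need, flop_tolerant = 150.0, 10.0

loss, scale = get_flop_loss(expected_flop, flop_cur=100.0, flop_need=flop_need, flop_tolerant=flop_tolerant)
print(scale)  # negative log term: too small, expected FLOP is pushed up

loss, scale = get_flop_loss(expected_flop, flop_cur=200.0, flop_need=flop_need, flop_tolerant=flop_tolerant)
print(scale)  # positive log term: too large, expected FLOP is pushed down

loss, scale = get_flop_loss(expected_flop, flop_cur=148.0, flop_need=flop_need, flop_tolerant=flop_tolerant)
print(loss, scale)  # (0, 0): within the required range, no extra loss
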
139
AutoDL-Projects/xautodl/procedures/search_main_v2.py
Normal file
@@ -0,0 +1,139 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
|
||||
# modules in AutoDL
|
||||
from xautodl.log_utils import AverageMeter, time_string
|
||||
from xautodl.models import change_key
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
|
||||
expected_flop = torch.mean(expected_flop)
|
||||
|
||||
if flop_cur < flop_need - flop_tolerant: # Too Small FLOP
|
||||
loss = -torch.log(expected_flop)
|
||||
# elif flop_cur > flop_need + flop_tolerant: # Too Large FLOP
|
||||
elif flop_cur > flop_need: # Too Large FLOP
|
||||
loss = torch.log(expected_flop)
|
||||
else: # Required FLOP
|
||||
loss = None
|
||||
if loss is None:
|
||||
return 0, 0
|
||||
else:
|
||||
return loss, loss.item()
|
||||
|
||||
|
||||
def search_train_v2(
|
||||
search_loader,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
base_optimizer,
|
||||
arch_optimizer,
|
||||
optim_config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
base_losses, arch_losses, top1, top5 = (
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
)
|
||||
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
|
||||
epoch_str, flop_need, flop_weight, flop_tolerant = (
|
||||
extra_info["epoch-str"],
|
||||
extra_info["FLOP-exp"],
|
||||
extra_info["FLOP-weight"],
|
||||
extra_info["FLOP-tolerant"],
|
||||
)
|
||||
|
||||
network.train()
|
||||
logger.log(
|
||||
"[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(
|
||||
epoch_str, flop_need, flop_weight
|
||||
)
|
||||
)
|
||||
end = time.time()
|
||||
network.apply(change_key("search_mode", "search"))
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
|
||||
search_loader
|
||||
):
|
||||
scheduler.update(None, 1.0 * step / len(search_loader))
|
||||
# calculate prediction and loss
|
||||
base_targets = base_targets.cuda(non_blocking=True)
|
||||
arch_targets = arch_targets.cuda(non_blocking=True)
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# update the weights
|
||||
base_optimizer.zero_grad()
|
||||
logits, expected_flop = network(base_inputs)
|
||||
base_loss = criterion(logits, base_targets)
|
||||
base_loss.backward()
|
||||
base_optimizer.step()
|
||||
# record
|
||||
prec1, prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
|
||||
base_losses.update(base_loss.item(), base_inputs.size(0))
|
||||
top1.update(prec1.item(), base_inputs.size(0))
|
||||
top5.update(prec5.item(), base_inputs.size(0))
|
||||
|
||||
# update the architecture
|
||||
arch_optimizer.zero_grad()
|
||||
logits, expected_flop = network(arch_inputs)
|
||||
flop_cur = network.module.get_flop("genotype", None, None)
|
||||
flop_loss, flop_loss_scale = get_flop_loss(
|
||||
expected_flop, flop_cur, flop_need, flop_tolerant
|
||||
)
|
||||
acls_loss = criterion(logits, arch_targets)
|
||||
arch_loss = acls_loss + flop_loss * flop_weight
|
||||
arch_loss.backward()
|
||||
arch_optimizer.step()
|
||||
|
||||
# record
|
||||
arch_losses.update(arch_loss.item(), arch_inputs.size(0))
|
||||
arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
|
||||
arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
if step % print_freq == 0 or (step + 1) == len(search_loader):
|
||||
Sstr = (
|
||||
"**TRAIN** "
|
||||
+ time_string()
|
||||
+ " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
|
||||
)
|
||||
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
|
||||
batch_time=batch_time, data_time=data_time
|
||||
)
|
||||
Lstr = "Base-Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
|
||||
loss=base_losses, top1=top1, top5=top5
|
||||
)
|
||||
Vstr = "Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})".format(
|
||||
aloss=arch_cls_losses, floss=arch_flop_losses, loss=arch_losses
|
||||
)
|
||||
logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Vstr)
|
||||
# num_bytes = torch.cuda.max_memory_allocated( next(network.parameters()).device ) * 1.0
|
||||
# logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' GPU={:.2f}MB'.format(num_bytes/1e6))
|
||||
# Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
|
||||
# logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
|
||||
# print(network.module.get_arch_info())
|
||||
# print(network.module.width_attentions[0])
|
||||
# print(network.module.width_attentions[1])
|
||||
|
||||
logger.log(
|
||||
" **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}".format(
|
||||
top1=top1,
|
||||
top5=top5,
|
||||
error1=100 - top1.avg,
|
||||
error5=100 - top5.avg,
|
||||
baseloss=base_losses.avg,
|
||||
archloss=arch_losses.avg,
|
||||
)
|
||||
)
|
||||
return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
|
||||
204
AutoDL-Projects/xautodl/procedures/simple_KD_main.py
Normal file
@@ -0,0 +1,204 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||
#####################################################
|
||||
import os, sys, time, torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# modules in AutoDL
|
||||
from xautodl.log_utils import AverageMeter, time_string
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
def simple_KD_train(
|
||||
xloader,
|
||||
teacher,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
optimizer,
|
||||
optim_config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
):
|
||||
loss, acc1, acc5 = procedure(
|
||||
xloader,
|
||||
teacher,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
optimizer,
|
||||
"train",
|
||||
optim_config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def simple_KD_valid(
|
||||
xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger
|
||||
):
|
||||
with torch.no_grad():
|
||||
loss, acc1, acc5 = procedure(
|
||||
xloader,
|
||||
teacher,
|
||||
network,
|
||||
criterion,
|
||||
None,
|
||||
None,
|
||||
"valid",
|
||||
optim_config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def loss_KD_fn(
|
||||
criterion,
|
||||
student_logits,
|
||||
teacher_logits,
|
||||
studentFeatures,
|
||||
teacherFeatures,
|
||||
targets,
|
||||
alpha,
|
||||
temperature,
|
||||
):
|
||||
basic_loss = criterion(student_logits, targets) * (1.0 - alpha)
|
||||
log_student = F.log_softmax(student_logits / temperature, dim=1)
|
||||
sof_teacher = F.softmax(teacher_logits / temperature, dim=1)
|
||||
KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * (
|
||||
alpha * temperature * temperature
|
||||
)
|
||||
return basic_loss + KD_loss
|
||||
|
||||
|
||||
def procedure(
|
||||
xloader,
|
||||
teacher,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
optimizer,
|
||||
mode,
|
||||
config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
):
|
||||
data_time, batch_time, losses, top1, top5 = (
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
)
|
||||
Ttop1, Ttop5 = AverageMeter(), AverageMeter()
|
||||
if mode == "train":
|
||||
network.train()
|
||||
elif mode == "valid":
|
||||
network.eval()
|
||||
else:
|
||||
raise ValueError("The mode is not right : {:}".format(mode))
|
||||
teacher.eval()
|
||||
|
||||
logger.log(
|
||||
"[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]".format(
|
||||
mode,
|
||||
config.auxiliary if hasattr(config, "auxiliary") else -1,
|
||||
config.KD_alpha,
|
||||
config.KD_temperature,
|
||||
)
|
||||
)
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
if mode == "train":
|
||||
scheduler.update(None, 1.0 * i / len(xloader))
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
# calculate prediction and loss
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
|
||||
if mode == "train":
|
||||
optimizer.zero_grad()
|
||||
|
||||
student_f, logits = network(inputs)
|
||||
if isinstance(logits, list):
|
||||
assert len(logits) == 2, "logits must have {:} items instead of {:}".format(
|
||||
2, len(logits)
|
||||
)
|
||||
logits, logits_aux = logits
|
||||
else:
|
||||
logits, logits_aux = logits, None
|
||||
with torch.no_grad():
|
||||
teacher_f, teacher_logits = teacher(inputs)
|
||||
|
||||
loss = loss_KD_fn(
|
||||
criterion,
|
||||
logits,
|
||||
teacher_logits,
|
||||
student_f,
|
||||
teacher_f,
|
||||
targets,
|
||||
config.KD_alpha,
|
||||
config.KD_temperature,
|
||||
)
|
||||
if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0:
|
||||
loss_aux = criterion(logits_aux, targets)
|
||||
loss += config.auxiliary * loss_aux
|
||||
|
||||
if mode == "train":
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# record
|
||||
sprec1, sprec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(sprec1.item(), inputs.size(0))
|
||||
top5.update(sprec5.item(), inputs.size(0))
|
||||
# teacher
|
||||
tprec1, tprec5 = obtain_accuracy(teacher_logits.data, targets.data, topk=(1, 5))
|
||||
Ttop1.update(tprec1.item(), inputs.size(0))
|
||||
Ttop5.update(tprec5.item(), inputs.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % print_freq == 0 or (i + 1) == len(xloader):
|
||||
Sstr = (
|
||||
" {:5s} ".format(mode.upper())
|
||||
+ time_string()
|
||||
+ " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
|
||||
)
|
||||
if scheduler is not None:
|
||||
Sstr += " {:}".format(scheduler.get_min_info())
|
||||
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
|
||||
batch_time=batch_time, data_time=data_time
|
||||
)
|
||||
Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
|
||||
loss=losses, top1=top1, top5=top5
|
||||
)
|
||||
Lstr += " Teacher : acc@1={:.2f}, acc@5={:.2f}".format(Ttop1.avg, Ttop5.avg)
|
||||
Istr = "Size={:}".format(list(inputs.size()))
|
||||
logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)
|
||||
|
||||
logger.log(
|
||||
" **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}".format(
|
||||
mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg
|
||||
)
|
||||
)
|
||||
logger.log(
|
||||
" **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
|
||||
mode=mode.upper(),
|
||||
top1=top1,
|
||||
top5=top5,
|
||||
error1=100 - top1.avg,
|
||||
error5=100 - top5.avg,
|
||||
loss=losses.avg,
|
||||
)
|
||||
)
|
||||
return losses.avg, top1.avg, top5.avg
|
||||
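
A toy sketch of loss_KD_fn above: hard-label cross-entropy weighted by (1 - alpha) plus a temperature-scaled KL term against the teacher (not part of the committed file; shapes, alpha and temperature are illustrative assumptions):

import torch
from xautodl.procedures.simple_KD_main import loss_KD_fn

criterion = torch.nn.CrossEntropyLoss()
student_logits = torch.randn(4, 10, requires_grad=True)
teacher_logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))

loss = loss_KD_fn(
    criterion, student_logits, teacher_logits,
    studentFeatures=None, teacherFeatures=None,  # accepted but unused by this implementation
    targets=targets, alpha=0.9, temperature=4.0,
)
loss.backward()
print(student_logits.grad.shape)  # gradients flow only through the student
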
79
AutoDL-Projects/xautodl/procedures/starts.py
Normal file
@@ -0,0 +1,79 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, torch, random, PIL, copy, numpy as np
from os import path as osp
from shutil import copyfile


def prepare_seed(rand_seed):
    random.seed(rand_seed)
    np.random.seed(rand_seed)
    torch.manual_seed(rand_seed)
    torch.cuda.manual_seed(rand_seed)
    torch.cuda.manual_seed_all(rand_seed)


def prepare_logger(xargs):
    args = copy.deepcopy(xargs)
    from xautodl.log_utils import Logger

    logger = Logger(args.save_dir, args.rand_seed)
    logger.log("Main Function with logger : {:}".format(logger))
    logger.log("Arguments : -------------------------------")
    for name, value in args._get_kwargs():
        logger.log("{:16} : {:}".format(name, value))
    logger.log("Python Version : {:}".format(sys.version.replace("\n", " ")))
    logger.log("Pillow Version : {:}".format(PIL.__version__))
    logger.log("PyTorch Version : {:}".format(torch.__version__))
    logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
    logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
    logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
    logger.log(
        "CUDA_VISIBLE_DEVICES : {:}".format(
            os.environ["CUDA_VISIBLE_DEVICES"]
            if "CUDA_VISIBLE_DEVICES" in os.environ
            else "None"
        )
    )
    return logger


def get_machine_info():
    info = "Python Version : {:}".format(sys.version.replace("\n", " "))
    info += "\nPillow Version : {:}".format(PIL.__version__)
    info += "\nPyTorch Version : {:}".format(torch.__version__)
    info += "\ncuDNN Version : {:}".format(torch.backends.cudnn.version())
    info += "\nCUDA available : {:}".format(torch.cuda.is_available())
    info += "\nCUDA GPU numbers : {:}".format(torch.cuda.device_count())
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        info += "\nCUDA_VISIBLE_DEVICES={:}".format(os.environ["CUDA_VISIBLE_DEVICES"])
    else:
        info += "\nCUDA_VISIBLE_DEVICES is not set"
    return info


def save_checkpoint(state, filename, logger):
    if osp.isfile(filename):
        if hasattr(logger, "log"):
            logger.log(
                "Find {:} exists, delete it first before saving".format(filename)
            )
        os.remove(filename)
    torch.save(state, filename)
    assert osp.isfile(
        filename
    ), "failed to save the checkpoint into {:}".format(filename)
    if hasattr(logger, "log"):
        logger.log("save checkpoint into {:}".format(filename))
    return filename


def copy_checkpoint(src, dst, logger):
    if osp.isfile(dst):
        if hasattr(logger, "log"):
            logger.log("Find {:} exists, delete it first before saving".format(dst))
        os.remove(dst)
    copyfile(src, dst)
    if hasattr(logger, "log"):
        logger.log("copy the file from {:} into {:}".format(src, dst))
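A minimal driver sketch for the helpers above, assuming the package is installed and this module is importable as xautodl.procedures.starts; the argument names follow the save_dir / rand_seed convention that prepare_logger expects:

import argparse

from xautodl.procedures.starts import get_machine_info, prepare_logger, prepare_seed

# Hypothetical arguments; prepare_logger only needs save_dir and rand_seed attributes.
parser = argparse.ArgumentParser()
parser.add_argument("--save_dir", type=str, default="./outputs/demo")
parser.add_argument("--rand_seed", type=int, default=2021)
args = parser.parse_args([])  # use the defaults for this sketch

prepare_seed(args.rand_seed)   # seed python / numpy / torch (and CUDA when present)
logger = prepare_logger(args)  # logs the arguments and the environment
logger.log(get_machine_info())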
17
AutoDL-Projects/xautodl/spaces/__init__.py
Normal file
@@ -0,0 +1,17 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
#####################################################
# Define complex search space for AutoDL            #
#####################################################

from .basic_space import Categorical
from .basic_space import Continuous
from .basic_space import Integer
from .basic_space import Space
from .basic_space import VirtualNode
from .basic_op import has_categorical
from .basic_op import has_continuous
from .basic_op import is_determined
from .basic_op import get_determined_value
from .basic_op import get_min
from .basic_op import get_max
74
AutoDL-Projects/xautodl/spaces/basic_op.py
Normal file
@@ -0,0 +1,74 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
from .basic_space import Space
from .basic_space import VirtualNode
from .basic_space import Integer
from .basic_space import Continuous
from .basic_space import Categorical
from .basic_space import _EPS


def has_categorical(space_or_value, x):
    if isinstance(space_or_value, Space):
        return space_or_value.has(x)
    else:
        return space_or_value == x


def has_continuous(space_or_value, x):
    if isinstance(space_or_value, Space):
        return space_or_value.has(x)
    else:
        return abs(space_or_value - x) <= _EPS


def is_determined(space_or_value):
    if isinstance(space_or_value, Space):
        return space_or_value.determined
    else:
        return True


def get_determined_value(space_or_value):
    if not is_determined(space_or_value):
        raise ValueError("This input is not determined: {:}".format(space_or_value))
    if isinstance(space_or_value, Space):
        if isinstance(space_or_value, Continuous):
            return space_or_value.lower
        elif isinstance(space_or_value, Categorical):
            return get_determined_value(space_or_value[0])
        else:  # VirtualNode
            return space_or_value.value
    else:
        return space_or_value


def get_max(space_or_value):
    if isinstance(space_or_value, Integer):
        return max(space_or_value.candidates)
    elif isinstance(space_or_value, Continuous):
        return space_or_value.upper
    elif isinstance(space_or_value, Categorical):
        values = []
        for index in range(len(space_or_value)):
            max_value = get_max(space_or_value[index])
            values.append(max_value)
        return max(values)
    else:
        return space_or_value


def get_min(space_or_value):
    if isinstance(space_or_value, Integer):
        return min(space_or_value.candidates)
    elif isinstance(space_or_value, Continuous):
        return space_or_value.lower
    elif isinstance(space_or_value, Categorical):
        values = []
        for index in range(len(space_or_value)):
            min_value = get_min(space_or_value[index])
            values.append(min_value)
        return min(values)
    else:
        return space_or_value
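A short sketch of how these helpers behave on plain values versus Space instances (assuming the package is importable as xautodl.spaces, which re-exports them):

from xautodl import spaces

# Plain Python values act as already-determined, degenerate spaces.
assert spaces.is_determined(4) and spaces.get_determined_value(4) == 4
assert spaces.has_continuous(0.5, 0.5 + 1e-12)  # equality within _EPS

# For real search spaces the helpers recurse into the candidates.
choice = spaces.Categorical(8, 16, spaces.Continuous(0.1, 0.9))
print(spaces.is_determined(choice))  # False: more than one candidate
print(spaces.get_min(choice))        # 0.1, the lower bound of the nested space
print(spaces.get_max(choice))        # 16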
434
AutoDL-Projects/xautodl/spaces/basic_space.py
Normal file
@@ -0,0 +1,434 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#####################################################
|
||||
|
||||
import abc
|
||||
import math
|
||||
import copy
|
||||
import random
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
|
||||
from typing import Optional, Text
|
||||
|
||||
|
||||
__all__ = ["_EPS", "Space", "Categorical", "Integer", "Continuous"]
|
||||
|
||||
_EPS = 1e-9
|
||||
|
||||
|
||||
class Space(metaclass=abc.ABCMeta):
|
||||
"""Basic search space describing the set of possible candidate values for hyperparameter.
|
||||
All search space must inherit from this basic class.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# used to avoid duplicate sample
|
||||
self._last_sample = None
|
||||
self._last_abstract = None
|
||||
|
||||
@abc.abstractproperty
|
||||
def xrepr(self, depth=0) -> Text:
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self) -> Text:
|
||||
return self.xrepr()
|
||||
|
||||
@abc.abstractproperty
|
||||
def abstract(self, reuse_last=False) -> "Space":
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def random(self, recursion=True, reuse_last=False):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def clean_last_sample(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def clean_last_abstract(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def clean_last(self):
|
||||
self.clean_last_sample()
|
||||
self.clean_last_abstract()
|
||||
|
||||
@abc.abstractproperty
|
||||
def determined(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def has(self, x) -> bool:
|
||||
"""Check whether x is in this search space."""
|
||||
assert not isinstance(
|
||||
x, Space
|
||||
), "The input value itself can not be a search space."
|
||||
|
||||
@abc.abstractmethod
|
||||
def __eq__(self, other):
|
||||
raise NotImplementedError
|
||||
|
||||
def copy(self) -> "Space":
|
||||
return copy.deepcopy(self)
|
||||
|
||||
|
||||
class VirtualNode(Space):
|
||||
"""For a nested search space, we represent it as a tree structure.
|
||||
|
||||
For example,
|
||||
"""
|
||||
|
||||
def __init__(self, id=None, value=None):
|
||||
super(VirtualNode, self).__init__()
|
||||
self._id = id
|
||||
self._value = value
|
||||
self._attributes = OrderedDict()
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self._value
|
||||
|
||||
def append(self, key, value):
|
||||
if not isinstance(key, str):
|
||||
raise TypeError(
|
||||
"Only accept string as a key instead of {:}".format(type(key))
|
||||
)
|
||||
if not isinstance(value, Space):
|
||||
raise ValueError("Invalid type of value: {:}".format(type(value)))
|
||||
# if value.determined:
|
||||
# raise ValueError("Can not attach a determined value: {:}".format(value))
|
||||
self._attributes[key] = value
|
||||
|
||||
def xrepr(self, depth=0) -> Text:
|
||||
strs = [self.__class__.__name__ + "(value={:}".format(self._value)]
|
||||
for key, value in self._attributes.items():
|
||||
strs.append(key + " = " + value.xrepr(depth + 1))
|
||||
strs.append(")")
|
||||
if len(strs) == 2:
|
||||
return "".join(strs)
|
||||
else:
|
||||
space = " "
|
||||
xstrs = (
|
||||
[strs[0]]
|
||||
+ [space * (depth + 1) + x for x in strs[1:-1]]
|
||||
+ [space * depth + strs[-1]]
|
||||
)
|
||||
return ",\n".join(xstrs)
|
||||
|
||||
def abstract(self, reuse_last=False) -> Space:
|
||||
if reuse_last and self._last_abstract is not None:
|
||||
return self._last_abstract
|
||||
node = VirtualNode(id(self))
|
||||
for key, value in self._attributes.items():
|
||||
if not value.determined:
|
||||
node.append(key, value.abstract(reuse_last))
|
||||
self._last_abstract = node
|
||||
return self._last_abstract
|
||||
|
||||
def random(self, recursion=True, reuse_last=False):
|
||||
if reuse_last and self._last_sample is not None:
|
||||
return self._last_sample
|
||||
node = VirtualNode(None, self._value)
|
||||
for key, value in self._attributes.items():
|
||||
node.append(key, value.random(recursion, reuse_last))
|
||||
self._last_sample = node # record the last sample
|
||||
return node
|
||||
|
||||
def clean_last_sample(self):
|
||||
self._last_sample = None
|
||||
for key, value in self._attributes.items():
|
||||
value.clean_last_sample()
|
||||
|
||||
def clean_last_abstract(self):
|
||||
self._last_abstract = None
|
||||
for key, value in self._attributes.items():
|
||||
value.clean_last_abstract()
|
||||
|
||||
def has(self, x) -> bool:
|
||||
for key, value in self._attributes.items():
|
||||
if value.has(x):
|
||||
return True
|
||||
return False
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self._attributes
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._attributes[key]
|
||||
|
||||
@property
|
||||
def determined(self) -> bool:
|
||||
for key, value in self._attributes.items():
|
||||
if not value.determined:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, VirtualNode):
|
||||
return False
|
||||
for key, value in self._attributes.items():
|
||||
if not key in other:
|
||||
return False
|
||||
if value != other[key]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class Categorical(Space):
|
||||
"""A space contains the categorical values.
|
||||
It can be a nested space, which means that the candidate in this space can also be a search space.
|
||||
"""
|
||||
|
||||
def __init__(self, *data, default: Optional[int] = None):
|
||||
super(Categorical, self).__init__()
|
||||
self._candidates = [*data]
|
||||
self._default = default
|
||||
assert self._default is None or 0 <= self._default < len(
|
||||
self._candidates
|
||||
), "default >= {:}".format(len(self._candidates))
|
||||
assert len(self) > 0, "Please provide at least one candidate"
|
||||
|
||||
@property
|
||||
def candidates(self):
|
||||
return self._candidates
|
||||
|
||||
@property
|
||||
def default(self):
|
||||
return self._default
|
||||
|
||||
@property
|
||||
def determined(self):
|
||||
if len(self) == 1:
|
||||
return (
|
||||
not isinstance(self._candidates[0], Space)
|
||||
or self._candidates[0].determined
|
||||
)
|
||||
else:
|
||||
return False
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self._candidates[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._candidates)
|
||||
|
||||
def clean_last_sample(self):
|
||||
self._last_sample = None
|
||||
for candidate in self._candidates:
|
||||
if isinstance(candidate, Space):
|
||||
candidate.clean_last_sample()
|
||||
|
||||
def clean_last_abstract(self):
|
||||
self._last_abstract = None
|
||||
for candidate in self._candidates:
|
||||
if isinstance(candidate, Space):
|
||||
candidate.clean_last_abstract()
|
||||
|
||||
def abstract(self, reuse_last=False) -> Space:
|
||||
if reuse_last and self._last_abstract is not None:
|
||||
return self._last_abstract
|
||||
if self.determined:
|
||||
result = VirtualNode(id(self), self)
|
||||
else:
|
||||
# [TO-IMPROVE]
|
||||
data = []
|
||||
for candidate in self.candidates:
|
||||
if isinstance(candidate, Space):
|
||||
data.append(candidate.abstract())
|
||||
else:
|
||||
data.append(VirtualNode(id(candidate), candidate))
|
||||
result = Categorical(*data, default=self._default)
|
||||
self._last_abstract = result
|
||||
return self._last_abstract
|
||||
|
||||
def random(self, recursion=True, reuse_last=False):
|
||||
if reuse_last and self._last_sample is not None:
|
||||
return self._last_sample
|
||||
sample = random.choice(self._candidates)
|
||||
if recursion and isinstance(sample, Space):
|
||||
sample = sample.random(recursion, reuse_last)
|
||||
if isinstance(sample, VirtualNode):
|
||||
sample = sample.copy()
|
||||
else:
|
||||
sample = VirtualNode(None, sample)
|
||||
self._last_sample = sample
|
||||
return self._last_sample
|
||||
|
||||
def xrepr(self, depth=0):
|
||||
del depth
|
||||
xrepr = "{name:}(candidates={cs:}, default_index={default:})".format(
|
||||
name=self.__class__.__name__, cs=self._candidates, default=self._default
|
||||
)
|
||||
return xrepr
|
||||
|
||||
def has(self, x):
|
||||
super().has(x)
|
||||
for candidate in self._candidates:
|
||||
if isinstance(candidate, Space) and candidate.has(x):
|
||||
return True
|
||||
elif candidate == x:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Categorical):
|
||||
return False
|
||||
if len(self) != len(other):
|
||||
return False
|
||||
if self.default != other.default:
|
||||
return False
|
||||
for index in range(len(self)):
|
||||
if self.__getitem__(index) != other[index]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class Integer(Categorical):
|
||||
"""A space contains the integer values."""
|
||||
|
||||
def __init__(self, lower: int, upper: int, default: Optional[int] = None):
|
||||
if not isinstance(lower, int) or not isinstance(upper, int):
|
||||
raise ValueError(
|
||||
"The lower [{:}] and uppwer [{:}] must be int.".format(lower, upper)
|
||||
)
|
||||
data = list(range(lower, upper + 1))
|
||||
self._raw_lower = lower
|
||||
self._raw_upper = upper
|
||||
self._raw_default = default
|
||||
        if default is not None:
            if default < lower or default > upper:
                raise ValueError("The default value [{:}] is out of range.".format(default))
            default = data.index(default)
|
||||
super(Integer, self).__init__(*data, default=default)
|
||||
|
||||
def xrepr(self, depth=0):
|
||||
del depth
|
||||
xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format(
|
||||
name=self.__class__.__name__,
|
||||
lower=self._raw_lower,
|
||||
upper=self._raw_upper,
|
||||
default=self._raw_default,
|
||||
)
|
||||
return xrepr
|
||||
|
||||
|
||||
np_float_types = (np.float16, np.float32, np.float64)
|
||||
np_int_types = (
|
||||
np.uint8,
|
||||
np.int8,
|
||||
np.uint16,
|
||||
np.int16,
|
||||
np.uint32,
|
||||
np.int32,
|
||||
np.uint64,
|
||||
np.int64,
|
||||
)
|
||||
|
||||
|
||||
class Continuous(Space):
|
||||
"""A space contains the continuous values."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lower: float,
|
||||
upper: float,
|
||||
default: Optional[float] = None,
|
||||
log: bool = False,
|
||||
eps: float = _EPS,
|
||||
):
|
||||
super(Continuous, self).__init__()
|
||||
self._lower = lower
|
||||
self._upper = upper
|
||||
self._default = default
|
||||
self._log_scale = log
|
||||
self._eps = eps
|
||||
|
||||
@property
|
||||
def lower(self):
|
||||
return self._lower
|
||||
|
||||
@property
|
||||
def upper(self):
|
||||
return self._upper
|
||||
|
||||
@property
|
||||
def default(self):
|
||||
return self._default
|
||||
|
||||
@property
|
||||
def use_log(self):
|
||||
return self._log_scale
|
||||
|
||||
@property
|
||||
def eps(self):
|
||||
return self._eps
|
||||
|
||||
def abstract(self, reuse_last=False) -> Space:
|
||||
if reuse_last and self._last_abstract is not None:
|
||||
return self._last_abstract
|
||||
self._last_abstract = self.copy()
|
||||
return self._last_abstract
|
||||
|
||||
def random(self, recursion=True, reuse_last=False):
|
||||
del recursion
|
||||
if reuse_last and self._last_sample is not None:
|
||||
return self._last_sample
|
||||
if self._log_scale:
|
||||
sample = random.uniform(math.log(self._lower), math.log(self._upper))
|
||||
sample = math.exp(sample)
|
||||
else:
|
||||
sample = random.uniform(self._lower, self._upper)
|
||||
self._last_sample = VirtualNode(None, sample)
|
||||
return self._last_sample
|
||||
|
||||
def xrepr(self, depth=0):
|
||||
del depth
|
||||
xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format(
|
||||
name=self.__class__.__name__,
|
||||
lower=self._lower,
|
||||
upper=self._upper,
|
||||
default=self._default,
|
||||
log=self._log_scale,
|
||||
)
|
||||
return xrepr
|
||||
|
||||
def convert(self, x):
|
||||
if isinstance(x, np_float_types) and x.size == 1:
|
||||
return float(x), True
|
||||
elif isinstance(x, np_int_types) and x.size == 1:
|
||||
return float(x), True
|
||||
elif isinstance(x, int):
|
||||
return float(x), True
|
||||
elif isinstance(x, float):
|
||||
return float(x), True
|
||||
else:
|
||||
return None, False
|
||||
|
||||
def has(self, x):
|
||||
super().has(x)
|
||||
converted_x, success = self.convert(x)
|
||||
return success and self.lower <= converted_x <= self.upper
|
||||
|
||||
@property
|
||||
def determined(self):
|
||||
return abs(self.lower - self.upper) <= self._eps
|
||||
|
||||
def clean_last_sample(self):
|
||||
self._last_sample = None
|
||||
|
||||
def clean_last_abstract(self):
|
||||
self._last_abstract = None
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Continuous):
|
||||
return False
|
||||
if self is other:
|
||||
return True
|
||||
else:
|
||||
return (
|
||||
self.lower == other.lower
|
||||
and self.upper == other.upper
|
||||
and self.default == other.default
|
||||
and self.use_log == other.use_log
|
||||
and self.eps == other.eps
|
||||
)
|
||||
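A sketch of composing and sampling these spaces; the attribute names (hidden / dropout / act) are made up for illustration, and the package is assumed importable as xautodl.spaces:

from xautodl import spaces

cfg = spaces.VirtualNode()
cfg.append("hidden", spaces.Integer(32, 256, default=64))
cfg.append("dropout", spaces.Continuous(0.0, 0.5))
cfg.append("act", spaces.Categorical("relu", "gelu"))

sample = cfg.random()  # a VirtualNode whose leaves hold concrete values
print(sample["hidden"].value, sample["dropout"].value, sample["act"].value)
print(cfg.random(reuse_last=True) is sample)  # True: the last sample is cached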
4
AutoDL-Projects/xautodl/trade_models/__init__.py
Normal file
@@ -0,0 +1,4 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
from .transformers import get_transformer
102
AutoDL-Projects/xautodl/trade_models/naive_v1_model.py
Normal file
@@ -0,0 +1,102 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021 #
|
||||
##################################################
|
||||
# Use noise as prediction #
|
||||
##################################################
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
from qlib.model.base import Model
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class NAIVE_V1(Model):
|
||||
"""NAIVE Version 1 Quant Model"""
|
||||
|
||||
def __init__(self, d_feat=6, seed=None, **kwargs):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("NAIVE")
|
||||
self.logger.info("NAIVE 1st version: random noise ...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"NAIVE-V1 parameters setting: d_feat={:}, seed={:}".format(
|
||||
self.d_feat, self.seed
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
random.seed(self.seed)
|
||||
np.random.seed(self.seed)
|
||||
self._mean = None
|
||||
self._std = None
|
||||
self.fitted = False
|
||||
|
||||
def process_data(self, features):
|
||||
features = features.reshape(len(features), self.d_feat, -1)
|
||||
features = features.transpose((0, 2, 1))
|
||||
return features[:, :59, 0]
|
||||
|
||||
def mse(self, preds, labels):
|
||||
masks = ~np.isnan(labels)
|
||||
masked_preds = preds[masks]
|
||||
masked_labels = labels[masks]
|
||||
return np.square(masked_preds - masked_labels).mean()
|
||||
|
||||
def model(self, x):
|
||||
num = len(x)
|
||||
return np.random.normal(loc=self._mean, scale=self._std, size=num).astype(
|
||||
x.dtype
|
||||
)
|
||||
|
||||
def fit(self, dataset: DatasetH):
|
||||
def _prepare_dataset(df_data):
|
||||
features = df_data["feature"].values
|
||||
features = self.process_data(features)
|
||||
labels = df_data["label"].values.squeeze()
|
||||
return dict(features=features, labels=labels)
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
train_dataset, valid_dataset, test_dataset = (
|
||||
_prepare_dataset(df_train),
|
||||
_prepare_dataset(df_valid),
|
||||
_prepare_dataset(df_test),
|
||||
)
|
||||
# df_train['feature']['CLOSE1'].values
|
||||
# train_dataset['features'][:, -1]
|
||||
masks = ~np.isnan(train_dataset["labels"])
|
||||
self._mean, self._std = np.mean(train_dataset["labels"][masks]), np.std(
|
||||
train_dataset["labels"][masks]
|
||||
)
|
||||
train_mse_loss = self.mse(
|
||||
self.model(train_dataset["features"]), train_dataset["labels"]
|
||||
)
|
||||
valid_mse_loss = self.mse(
|
||||
self.model(valid_dataset["features"]), valid_dataset["labels"]
|
||||
)
|
||||
self.logger.info("Training MSE loss: {:}".format(train_mse_loss))
|
||||
self.logger.info("Validation MSE loss: {:}".format(valid_mse_loss))
|
||||
self.fitted = True
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self.fitted:
|
||||
raise ValueError("The model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
index = x_test.index
|
||||
|
||||
preds = self.model(self.process_data(x_test.values))
|
||||
return pd.Series(preds, index=index)
|
||||
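Stripped of the qlib plumbing, this baseline just fits a Gaussian to the non-NaN training labels and samples predictions from it; a numpy-only sketch with made-up data:

import numpy as np

rng = np.random.default_rng(0)
labels = rng.normal(0.0, 0.02, size=1000)
labels[::50] = np.nan  # qlib labels may contain NaNs

masks = ~np.isnan(labels)
mean, std = labels[masks].mean(), labels[masks].std()

preds = rng.normal(loc=mean, scale=std, size=len(labels))  # the "model"
mse = np.square(preds[masks] - labels[masks]).mean()       # masked MSE, as in NAIVE_V1.mse
print(mean, std, mse)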
103
AutoDL-Projects/xautodl/trade_models/naive_v2_model.py
Normal file
@@ -0,0 +1,103 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021 #
|
||||
##################################################
|
||||
# A Simple Model that reused the prices of last day
|
||||
##################################################
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
from qlib.model.base import Model
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class NAIVE_V2(Model):
|
||||
"""NAIVE Version 2 Quant Model"""
|
||||
|
||||
def __init__(self, d_feat=6, seed=None, **kwargs):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("NAIVE")
|
||||
self.logger.info("NAIVE version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"NAIVE parameters setting: d_feat={:}, seed={:}".format(
|
||||
self.d_feat, self.seed
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
random.seed(self.seed)
|
||||
np.random.seed(self.seed)
|
||||
|
||||
self.fitted = False
|
||||
|
||||
def process_data(self, features):
|
||||
features = features.reshape(len(features), self.d_feat, -1)
|
||||
features = features.transpose((0, 2, 1))
|
||||
return features[:, :59, 0]
|
||||
|
||||
def mse(self, preds, labels):
|
||||
masks = ~np.isnan(labels)
|
||||
masked_preds = preds[masks]
|
||||
masked_labels = labels[masks]
|
||||
return np.square(masked_preds - masked_labels).mean()
|
||||
|
||||
def model(self, x):
|
||||
x = 1 / x - 1
|
||||
masks = ~np.isnan(x)
|
||||
results = []
|
||||
for rowd, rowm in zip(x, masks):
|
||||
temp = rowd[rowm]
|
||||
if rowm.any():
|
||||
results.append(float(rowd[rowm][-1]))
|
||||
else:
|
||||
results.append(0)
|
||||
return np.array(results, dtype=x.dtype)
|
||||
|
||||
def fit(self, dataset: DatasetH):
|
||||
def _prepare_dataset(df_data):
|
||||
features = df_data["feature"].values
|
||||
features = self.process_data(features)
|
||||
labels = df_data["label"].values.squeeze()
|
||||
return dict(features=features, labels=labels)
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
train_dataset, valid_dataset, test_dataset = (
|
||||
_prepare_dataset(df_train),
|
||||
_prepare_dataset(df_valid),
|
||||
_prepare_dataset(df_test),
|
||||
)
|
||||
# df_train['feature']['CLOSE1'].values
|
||||
# train_dataset['features'][:, -1]
|
||||
train_mse_loss = self.mse(
|
||||
self.model(train_dataset["features"]), train_dataset["labels"]
|
||||
)
|
||||
valid_mse_loss = self.mse(
|
||||
self.model(valid_dataset["features"]), valid_dataset["labels"]
|
||||
)
|
||||
self.logger.info("Training MSE loss: {:}".format(train_mse_loss))
|
||||
self.logger.info("Validation MSE loss: {:}".format(valid_mse_loss))
|
||||
self.fitted = True
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self.fitted:
|
||||
raise ValueError("The model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
index = x_test.index
|
||||
|
||||
preds = self.model(self.process_data(x_test.values))
|
||||
return pd.Series(preds, index=index)
|
||||
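The model above turns each feature row of price ratios into returns via 1 / x - 1 and reuses the most recent non-NaN value as the prediction (0 when the whole row is NaN); a small numpy illustration with made-up rows:

import numpy as np

x = np.array([[0.99, 1.01, np.nan],        # last valid entry is 1.01
              [np.nan, np.nan, np.nan]])   # no valid entry -> predict 0

returns = 1 / x - 1
masks = ~np.isnan(returns)
preds = [float(row[mask][-1]) if mask.any() else 0.0 for row, mask in zip(returns, masks)]
print(preds)  # [1 / 1.01 - 1 ~= -0.0099, 0.0]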
358
AutoDL-Projects/xautodl/trade_models/quant_transformer.py
Normal file
@@ -0,0 +1,358 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021 #
|
||||
##################################################
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os, math, random
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from functools import partial
|
||||
from typing import Optional, Text
|
||||
|
||||
from qlib.utils import get_or_create_path
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import torch.utils.data as th_data
|
||||
|
||||
from xautodl.xmisc import AverageMeter
|
||||
from xautodl.xmisc import count_parameters
|
||||
|
||||
from xautodl.xlayers import super_core
|
||||
from .transformers import DEFAULT_NET_CONFIG
|
||||
from .transformers import get_transformer
|
||||
|
||||
|
||||
from qlib.model.base import Model
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
DEFAULT_OPT_CONFIG = dict(
|
||||
epochs=200,
|
||||
lr=0.001,
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
optimizer="adam",
|
||||
num_workers=4,
|
||||
)
|
||||
|
||||
|
||||
def train_or_test_epoch(
|
||||
xloader, model, loss_fn, metric_fn, is_train, optimizer, device
|
||||
):
|
||||
if is_train:
|
||||
model.train()
|
||||
else:
|
||||
model.eval()
|
||||
score_meter, loss_meter = AverageMeter(), AverageMeter()
|
||||
for ibatch, (feats, labels) in enumerate(xloader):
|
||||
feats, labels = feats.to(device), labels.to(device)
|
||||
# forward the network
|
||||
preds = model(feats)
|
||||
loss = loss_fn(preds, labels)
|
||||
with torch.no_grad():
|
||||
score = metric_fn(preds, labels)
|
||||
loss_meter.update(loss.item(), feats.size(0))
|
||||
score_meter.update(score.item(), feats.size(0))
|
||||
# optimize the network
|
||||
if is_train and optimizer is not None:
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(model.parameters(), 3.0)
|
||||
optimizer.step()
|
||||
return loss_meter.avg, score_meter.avg
|
||||
|
||||
|
||||
class QuantTransformer(Model):
|
||||
"""Transformer-based Quant Model"""
|
||||
|
||||
def __init__(
|
||||
self, net_config=None, opt_config=None, metric="", GPU=0, seed=None, **kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("QuantTransformer")
|
||||
self.logger.info("QuantTransformer PyTorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.net_config = net_config or DEFAULT_NET_CONFIG
|
||||
self.opt_config = opt_config or DEFAULT_OPT_CONFIG
|
||||
self.metric = metric
|
||||
self.device = torch.device(
|
||||
"cuda:{:}".format(GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu"
|
||||
)
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"Transformer parameters setting:"
|
||||
"\nnet_config : {:}"
|
||||
"\nopt_config : {:}"
|
||||
"\nmetric : {:}"
|
||||
"\ndevice : {:}"
|
||||
"\nseed : {:}".format(
|
||||
self.net_config,
|
||||
self.opt_config,
|
||||
self.metric,
|
||||
self.device,
|
||||
self.seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
random.seed(self.seed)
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
if self.use_gpu:
|
||||
torch.cuda.manual_seed(self.seed)
|
||||
torch.cuda.manual_seed_all(self.seed)
|
||||
|
||||
self.model = get_transformer(self.net_config)
|
||||
self.model.set_super_run_type(super_core.SuperRunMode.FullModel)
|
||||
self.logger.info("model: {:}".format(self.model))
|
||||
self.logger.info("model size: {:.3f} MB".format(count_parameters(self.model)))
|
||||
|
||||
if self.opt_config["optimizer"] == "adam":
|
||||
self.train_optimizer = optim.Adam(
|
||||
self.model.parameters(), lr=self.opt_config["lr"]
|
||||
)
|
||||
elif self.opt_config["optimizer"] == "adam":
|
||||
self.train_optimizer = optim.SGD(
|
||||
self.model.parameters(), lr=self.opt_config["lr"]
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"optimizer {:} is not supported!".format(optimizer)
|
||||
)
|
||||
|
||||
self.fitted = False
|
||||
self.model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def to(self, device):
|
||||
if device is None:
|
||||
device = "cpu"
|
||||
self.device = device
|
||||
self.model.to(self.device)
|
||||
# move the optimizer
|
||||
for param in self.train_optimizer.state.values():
|
||||
# Not sure there are any global tensors in the state dict
|
||||
if isinstance(param, torch.Tensor):
|
||||
param.data = param.data.to(device)
|
||||
if param._grad is not None:
|
||||
param._grad.data = param._grad.data.to(device)
|
||||
elif isinstance(param, dict):
|
||||
for subparam in param.values():
|
||||
if isinstance(subparam, torch.Tensor):
|
||||
subparam.data = subparam.data.to(device)
|
||||
if subparam._grad is not None:
|
||||
subparam._grad.data = subparam._grad.data.to(device)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
if self.opt_config["loss"] == "mse":
|
||||
return F.mse_loss(pred[mask], label[mask])
|
||||
else:
|
||||
raise ValueError("unknown loss `{:}`".format(self.loss))
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
# the metric score : higher is better
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred, label)
|
||||
else:
|
||||
raise ValueError("unknown metric `{:}`".format(self.metric))
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
save_dir: Optional[Text] = None,
|
||||
):
|
||||
def _prepare_dataset(df_data):
|
||||
return th_data.TensorDataset(
|
||||
torch.from_numpy(df_data["feature"].values).float(),
|
||||
torch.from_numpy(df_data["label"].values).squeeze().float(),
|
||||
)
|
||||
|
||||
def _prepare_loader(dataset, shuffle):
|
||||
return th_data.DataLoader(
|
||||
dataset,
|
||||
batch_size=self.opt_config["batch_size"],
|
||||
drop_last=False,
|
||||
pin_memory=True,
|
||||
num_workers=self.opt_config["num_workers"],
|
||||
shuffle=shuffle,
|
||||
)
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
train_dataset, valid_dataset, test_dataset = (
|
||||
_prepare_dataset(df_train),
|
||||
_prepare_dataset(df_valid),
|
||||
_prepare_dataset(df_test),
|
||||
)
|
||||
train_loader, valid_loader, test_loader = (
|
||||
_prepare_loader(train_dataset, True),
|
||||
_prepare_loader(valid_dataset, False),
|
||||
_prepare_loader(test_dataset, False),
|
||||
)
|
||||
|
||||
save_dir = get_or_create_path(save_dir, return_dir=True)
|
||||
self.logger.info(
|
||||
"Fit procedure for [{:}] with save path={:}".format(
|
||||
self.__class__.__name__, save_dir
|
||||
)
|
||||
)
|
||||
|
||||
def _internal_test(ckp_epoch=None, results_dict=None):
|
||||
with torch.no_grad():
|
||||
shared_kwards = {
|
||||
"model": self.model,
|
||||
"loss_fn": self.loss_fn,
|
||||
"metric_fn": self.metric_fn,
|
||||
"is_train": False,
|
||||
"optimizer": None,
|
||||
"device": self.device,
|
||||
}
|
||||
train_loss, train_score = train_or_test_epoch(
|
||||
train_loader, **shared_kwards
|
||||
)
|
||||
valid_loss, valid_score = train_or_test_epoch(
|
||||
valid_loader, **shared_kwards
|
||||
)
|
||||
test_loss, test_score = train_or_test_epoch(
|
||||
test_loader, **shared_kwards
|
||||
)
|
||||
xstr = (
|
||||
"train-score={:.6f}, valid-score={:.6f}, test-score={:.6f}".format(
|
||||
train_score, valid_score, test_score
|
||||
)
|
||||
)
|
||||
if ckp_epoch is not None and isinstance(results_dict, dict):
|
||||
results_dict["train"][ckp_epoch] = train_score
|
||||
results_dict["valid"][ckp_epoch] = valid_score
|
||||
results_dict["test"][ckp_epoch] = test_score
|
||||
return dict(train=train_score, valid=valid_score, test=test_score), xstr
|
||||
|
||||
# Pre-fetch the potential checkpoints
|
||||
ckp_path = os.path.join(save_dir, "{:}.pth".format(self.__class__.__name__))
|
||||
if os.path.exists(ckp_path):
|
||||
ckp_data = torch.load(ckp_path, map_location=self.device)
|
||||
stop_steps, best_score, best_epoch = (
|
||||
ckp_data["stop_steps"],
|
||||
ckp_data["best_score"],
|
||||
ckp_data["best_epoch"],
|
||||
)
|
||||
start_epoch, best_param = ckp_data["start_epoch"], ckp_data["best_param"]
|
||||
results_dict = ckp_data["results_dict"]
|
||||
self.model.load_state_dict(ckp_data["net_state_dict"])
|
||||
self.train_optimizer.load_state_dict(ckp_data["opt_state_dict"])
|
||||
self.logger.info("Resume from existing checkpoint: {:}".format(ckp_path))
|
||||
else:
|
||||
stop_steps, best_score, best_epoch = 0, -np.inf, -1
|
||||
start_epoch, best_param = 0, None
|
||||
results_dict = dict(
|
||||
train=OrderedDict(), valid=OrderedDict(), test=OrderedDict()
|
||||
)
|
||||
_, eval_str = _internal_test(-1, results_dict)
|
||||
self.logger.info(
|
||||
"Training from scratch, metrics@start: {:}".format(eval_str)
|
||||
)
|
||||
|
||||
for iepoch in range(start_epoch, self.opt_config["epochs"]):
|
||||
self.logger.info(
|
||||
"Epoch={:03d}/{:03d} ::==>> Best valid @{:03d} ({:.6f})".format(
|
||||
iepoch, self.opt_config["epochs"], best_epoch, best_score
|
||||
)
|
||||
)
|
||||
train_loss, train_score = train_or_test_epoch(
|
||||
train_loader,
|
||||
self.model,
|
||||
self.loss_fn,
|
||||
self.metric_fn,
|
||||
True,
|
||||
self.train_optimizer,
|
||||
self.device,
|
||||
)
|
||||
self.logger.info(
|
||||
"Training :: loss={:.6f}, score={:.6f}".format(train_loss, train_score)
|
||||
)
|
||||
|
||||
current_eval_scores, eval_str = _internal_test(iepoch, results_dict)
|
||||
self.logger.info("Evaluating :: {:}".format(eval_str))
|
||||
|
||||
if current_eval_scores["valid"] > best_score:
|
||||
stop_steps, best_epoch, best_score = (
|
||||
0,
|
||||
iepoch,
|
||||
current_eval_scores["valid"],
|
||||
)
|
||||
best_param = copy.deepcopy(self.model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.opt_config["early_stop"]:
|
||||
self.logger.info(
|
||||
"early stop at {:}-th epoch, where the best is @{:}".format(
|
||||
iepoch, best_epoch
|
||||
)
|
||||
)
|
||||
break
|
||||
save_info = dict(
|
||||
net_config=self.net_config,
|
||||
opt_config=self.opt_config,
|
||||
net_state_dict=self.model.state_dict(),
|
||||
opt_state_dict=self.train_optimizer.state_dict(),
|
||||
best_param=best_param,
|
||||
stop_steps=stop_steps,
|
||||
best_score=best_score,
|
||||
best_epoch=best_epoch,
|
||||
results_dict=results_dict,
|
||||
start_epoch=iepoch + 1,
|
||||
)
|
||||
torch.save(save_info, ckp_path)
|
||||
self.logger.info(
|
||||
"The best score: {:.6f} @ {:02d}-th epoch".format(best_score, best_epoch)
|
||||
)
|
||||
self.model.load_state_dict(best_param)
|
||||
_, eval_str = _internal_test("final", results_dict)
|
||||
self.logger.info("Reload the best parameter :: {:}".format(eval_str))
|
||||
|
||||
if self.use_gpu:
|
||||
with torch.cuda.device(self.device):
|
||||
torch.cuda.empty_cache()
|
||||
self.fitted = True
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("The model is not fitted yet!")
|
||||
x_test = dataset.prepare(
|
||||
segment, col_set="feature", data_key=DataHandlerLP.DK_I
|
||||
)
|
||||
index = x_test.index
|
||||
|
||||
with torch.no_grad():
|
||||
self.model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num, batch_size = x_values.shape[0], self.opt_config["batch_size"]
|
||||
preds = []
|
||||
for begin in range(sample_num)[::batch_size]:
|
||||
if sample_num - begin < batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + batch_size
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
with torch.no_grad():
|
||||
pred = self.model(x_batch).detach().cpu().numpy()
|
||||
preds.append(pred)
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
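A self-contained sketch of the epoch helper above with a dummy linear model and random tensors; it assumes xautodl (and qlib, which this module imports at the top) are installed:

import torch
import torch.nn as nn
import torch.utils.data as th_data

from xautodl.trade_models.quant_transformer import train_or_test_epoch

model = nn.Linear(16, 1)
feats, labels = torch.randn(256, 16), torch.randn(256)
loader = th_data.DataLoader(
    th_data.TensorDataset(feats, labels), batch_size=32, shuffle=True
)

loss_fn = lambda pred, label: nn.functional.mse_loss(pred.squeeze(-1), label)
metric_fn = lambda pred, label: -loss_fn(pred, label)  # higher is better, mirroring metric_fn
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

loss, score = train_or_test_epoch(
    loader, model, loss_fn, metric_fn, True, optimizer, torch.device("cpu")
)
print(loss, score)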
199
AutoDL-Projects/xautodl/trade_models/transformers.py
Normal file
@@ -0,0 +1,199 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#####################################################
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Optional, Text, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from xautodl import spaces
|
||||
from xautodl.xlayers import weight_init
|
||||
from xautodl.xlayers import super_core
|
||||
|
||||
|
||||
__all__ = ["DefaultSearchSpace", "DEFAULT_NET_CONFIG", "get_transformer"]
|
||||
|
||||
|
||||
def _get_mul_specs(candidates, num):
|
||||
results = []
|
||||
for i in range(num):
|
||||
results.append(spaces.Categorical(*candidates))
|
||||
return results
|
||||
|
||||
|
||||
def _get_list_mul(num, multipler):
|
||||
results = []
|
||||
for i in range(1, num + 1):
|
||||
results.append(i * multipler)
|
||||
return results
|
||||
|
||||
|
||||
def _assert_types(x, expected_types):
|
||||
if not isinstance(x, expected_types):
|
||||
raise TypeError(
|
||||
"The type [{:}] is expected to be {:}.".format(type(x), expected_types)
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_NET_CONFIG = None
|
||||
_default_max_depth = 6
|
||||
DefaultSearchSpace = dict(
|
||||
d_feat=6,
|
||||
embed_dim=32,
|
||||
# embed_dim=spaces.Categorical(*_get_list_mul(8, 16)),
|
||||
num_heads=[4] * _default_max_depth,
|
||||
mlp_hidden_multipliers=[4] * _default_max_depth,
|
||||
qkv_bias=True,
|
||||
pos_drop=0.0,
|
||||
other_drop=0.0,
|
||||
)
|
||||
|
||||
|
||||
class SuperTransformer(super_core.SuperModule):
|
||||
"""The super model for transformer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat: int = 6,
|
||||
embed_dim: List[super_core.IntSpaceType] = DefaultSearchSpace["embed_dim"],
|
||||
num_heads: List[super_core.IntSpaceType] = DefaultSearchSpace["num_heads"],
|
||||
mlp_hidden_multipliers: List[super_core.IntSpaceType] = DefaultSearchSpace[
|
||||
"mlp_hidden_multipliers"
|
||||
],
|
||||
qkv_bias: bool = DefaultSearchSpace["qkv_bias"],
|
||||
pos_drop: float = DefaultSearchSpace["pos_drop"],
|
||||
other_drop: float = DefaultSearchSpace["other_drop"],
|
||||
max_seq_len: int = 65,
|
||||
):
|
||||
super(SuperTransformer, self).__init__()
|
||||
self._embed_dim = embed_dim
|
||||
self._num_heads = num_heads
|
||||
self._mlp_hidden_multipliers = mlp_hidden_multipliers
|
||||
|
||||
# the stem part
|
||||
self.input_embed = super_core.SuperAlphaEBDv1(d_feat, embed_dim)
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
||||
self.pos_embed = super_core.SuperPositionalEncoder(
|
||||
d_model=embed_dim, max_seq_len=max_seq_len, dropout=pos_drop
|
||||
)
|
||||
# build the transformer encode layers -->> check params
|
||||
_assert_types(num_heads, (tuple, list))
|
||||
_assert_types(mlp_hidden_multipliers, (tuple, list))
|
||||
assert len(num_heads) == len(mlp_hidden_multipliers), "{:} vs {:}".format(
|
||||
len(num_heads), len(mlp_hidden_multipliers)
|
||||
)
|
||||
# build the transformer encode layers -->> backbone
|
||||
layers = []
|
||||
for num_head, mlp_hidden_multiplier in zip(num_heads, mlp_hidden_multipliers):
|
||||
layer = super_core.SuperTransformerEncoderLayer(
|
||||
embed_dim,
|
||||
num_head,
|
||||
qkv_bias,
|
||||
mlp_hidden_multiplier,
|
||||
other_drop,
|
||||
)
|
||||
layers.append(layer)
|
||||
self.backbone = super_core.SuperSequential(*layers)
|
||||
|
||||
# the regression head
|
||||
self.head = super_core.SuperSequential(
|
||||
super_core.SuperLayerNorm1D(embed_dim), super_core.SuperLinear(embed_dim, 1)
|
||||
)
|
||||
weight_init.trunc_normal_(self.cls_token, std=0.02)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
@property
|
||||
def embed_dim(self):
|
||||
return spaces.get_max(self._embed_dim)
|
||||
|
||||
@property
|
||||
def abstract_search_space(self):
|
||||
root_node = spaces.VirtualNode(id(self))
|
||||
if not spaces.is_determined(self._embed_dim):
|
||||
root_node.append("_embed_dim", self._embed_dim.abstract(reuse_last=True))
|
||||
xdict = dict(
|
||||
input_embed=self.input_embed.abstract_search_space,
|
||||
pos_embed=self.pos_embed.abstract_search_space,
|
||||
backbone=self.backbone.abstract_search_space,
|
||||
head=self.head.abstract_search_space,
|
||||
)
|
||||
for key, space in xdict.items():
|
||||
if not spaces.is_determined(space):
|
||||
root_node.append(key, space)
|
||||
return root_node
|
||||
|
||||
def apply_candidate(self, abstract_child: spaces.VirtualNode):
|
||||
super(SuperTransformer, self).apply_candidate(abstract_child)
|
||||
xkeys = ("input_embed", "pos_embed", "backbone", "head")
|
||||
for key in xkeys:
|
||||
if key in abstract_child:
|
||||
getattr(self, key).apply_candidate(abstract_child[key])
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
weight_init.trunc_normal_(m.weight, std=0.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, super_core.SuperLinear):
|
||||
weight_init.trunc_normal_(m._super_weight, std=0.02)
|
||||
if m._super_bias is not None:
|
||||
nn.init.constant_(m._super_bias, 0)
|
||||
elif isinstance(m, super_core.SuperLayerNorm1D):
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
|
||||
batch, flatten_size = input.shape
|
||||
feats = self.input_embed(input) # batch * 60 * 64
|
||||
if not spaces.is_determined(self._embed_dim):
|
||||
embed_dim = self.abstract_child["_embed_dim"].value
|
||||
else:
|
||||
embed_dim = spaces.get_determined_value(self._embed_dim)
|
||||
cls_tokens = self.cls_token.expand(batch, -1, -1)
|
||||
cls_tokens = F.interpolate(
|
||||
cls_tokens, size=(embed_dim), mode="linear", align_corners=True
|
||||
)
|
||||
feats_w_ct = torch.cat((cls_tokens, feats), dim=1)
|
||||
feats_w_tp = self.pos_embed(feats_w_ct)
|
||||
xfeats = self.backbone(feats_w_tp)
|
||||
xfeats = xfeats[:, 0, :] # use the feature for the first token
|
||||
predicts = self.head(xfeats).squeeze(-1)
|
||||
return predicts
|
||||
|
||||
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
|
||||
batch, flatten_size = input.shape
|
||||
feats = self.input_embed(input) # batch * 60 * 64
|
||||
cls_tokens = self.cls_token.expand(batch, -1, -1)
|
||||
feats_w_ct = torch.cat((cls_tokens, feats), dim=1)
|
||||
feats_w_tp = self.pos_embed(feats_w_ct)
|
||||
xfeats = self.backbone(feats_w_tp)
|
||||
xfeats = xfeats[:, 0, :] # use the feature for the first token
|
||||
predicts = self.head(xfeats).squeeze(-1)
|
||||
return predicts
|
||||
|
||||
|
||||
def get_transformer(config):
|
||||
if config is None:
|
||||
return SuperTransformer(6)
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError("Invalid Configuration: {:}".format(config))
|
||||
name = config.get("name", "basic")
|
||||
if name == "basic":
|
||||
model = SuperTransformer(
|
||||
d_feat=config.get("d_feat"),
|
||||
embed_dim=config.get("embed_dim"),
|
||||
num_heads=config.get("num_heads"),
|
||||
mlp_hidden_multipliers=config.get("mlp_hidden_multipliers"),
|
||||
qkv_bias=config.get("qkv_bias"),
|
||||
pos_drop=config.get("pos_drop"),
|
||||
other_drop=config.get("other_drop"),
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown model name: {:}".format(name))
|
||||
return model
|
||||
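A sketch of building the default super-network and running a forward pass; it assumes the qlib Alpha360-style flattened input layout of d_feat * seq_len = 6 * 60 features per sample, and the explicit set_super_run_type call mirrors what QuantTransformer.__init__ does:

import torch

from xautodl.trade_models.transformers import get_transformer
from xautodl.xlayers import super_core

net = get_transformer(None)  # the default SuperTransformer(d_feat=6)
net.set_super_run_type(super_core.SuperRunMode.FullModel)

x = torch.randn(8, 6 * 60)   # 8 samples, flattened Alpha360-style features
with torch.no_grad():
    scores = net(x)
print(scores.shape)          # torch.Size([8]) -- one regression score per sample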
14
AutoDL-Projects/xautodl/utils/__init__.py
Normal file
@@ -0,0 +1,14 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
# This directory contains some ad-hoc functions, classes, etc.
# It will be re-formulated in the future.
#####################################################
from .evaluation_utils import obtain_accuracy
from .gpu_manager import GPUManager
from .flop_benchmark import get_model_infos, count_parameters, count_parameters_in_MB
from .affine_utils import normalize_points, denormalize_points
from .affine_utils import identity2affine, solve2theta, affine2image
from .hash_utils import get_md5_file
from .str_utils import split_str2indexes
from .str_utils import show_mean_var
159
AutoDL-Projects/xautodl/utils/affine_utils.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# functions for affine transformation
|
||||
import math
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def identity2affine(full=False):
|
||||
if not full:
|
||||
parameters = torch.zeros((2, 3))
|
||||
parameters[0, 0] = parameters[1, 1] = 1
|
||||
else:
|
||||
parameters = torch.zeros((3, 3))
|
||||
parameters[0, 0] = parameters[1, 1] = parameters[2, 2] = 1
|
||||
return parameters
|
||||
|
||||
|
||||
def normalize_L(x, L):
|
||||
return -1.0 + 2.0 * x / (L - 1)
|
||||
|
||||
|
||||
def denormalize_L(x, L):
|
||||
return (x + 1.0) / 2.0 * (L - 1)
|
||||
|
||||
|
||||
def crop2affine(crop_box, W, H):
|
||||
assert len(crop_box) == 4, "Invalid crop-box : {:}".format(crop_box)
|
||||
parameters = torch.zeros(3, 3)
|
||||
x1, y1 = normalize_L(crop_box[0], W), normalize_L(crop_box[1], H)
|
||||
x2, y2 = normalize_L(crop_box[2], W), normalize_L(crop_box[3], H)
|
||||
parameters[0, 0] = (x2 - x1) / 2
|
||||
parameters[0, 2] = (x2 + x1) / 2
|
||||
|
||||
parameters[1, 1] = (y2 - y1) / 2
|
||||
parameters[1, 2] = (y2 + y1) / 2
|
||||
parameters[2, 2] = 1
|
||||
return parameters
|
||||
|
||||
|
||||
def scale2affine(scalex, scaley):
|
||||
parameters = torch.zeros(3, 3)
|
||||
parameters[0, 0] = scalex
|
||||
parameters[1, 1] = scaley
|
||||
parameters[2, 2] = 1
|
||||
return parameters
|
||||
|
||||
|
||||
def offset2affine(offx, offy):
|
||||
parameters = torch.zeros(3, 3)
|
||||
parameters[0, 0] = parameters[1, 1] = parameters[2, 2] = 1
|
||||
parameters[0, 2] = offx
|
||||
parameters[1, 2] = offy
|
||||
return parameters
|
||||
|
||||
|
||||
def horizontalmirror2affine():
|
||||
parameters = torch.zeros(3, 3)
|
||||
parameters[0, 0] = -1
|
||||
parameters[1, 1] = parameters[2, 2] = 1
|
||||
return parameters
|
||||
|
||||
|
||||
# clockwise rotate image = counterclockwise rotate the rectangle
|
||||
# degree is between [0, 360]
|
||||
def rotate2affine(degree):
|
||||
assert degree >= 0 and degree <= 360, "Invalid degree : {:}".format(degree)
|
||||
degree = degree / 180 * math.pi
|
||||
parameters = torch.zeros(3, 3)
|
||||
parameters[0, 0] = math.cos(-degree)
|
||||
parameters[0, 1] = -math.sin(-degree)
|
||||
parameters[1, 0] = math.sin(-degree)
|
||||
parameters[1, 1] = math.cos(-degree)
|
||||
parameters[2, 2] = 1
|
||||
return parameters
|
||||
|
||||
|
||||
# shape is a tuple [H, W]
|
||||
def normalize_points(shape, points):
|
||||
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(
|
||||
shape
|
||||
) == 2, "invalid shape : {:}".format(shape)
|
||||
assert isinstance(points, torch.Tensor) and (
|
||||
points.shape[0] == 2
|
||||
), "points are wrong : {:}".format(points.shape)
|
||||
(H, W), points = shape, points.clone()
|
||||
points[0, :] = normalize_L(points[0, :], W)
|
||||
points[1, :] = normalize_L(points[1, :], H)
|
||||
return points
|
||||
|
||||
|
||||
# shape is a tuple [H, W]
|
||||
def normalize_points_batch(shape, points):
|
||||
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(
|
||||
shape
|
||||
) == 2, "invalid shape : {:}".format(shape)
|
||||
assert isinstance(points, torch.Tensor) and (
|
||||
points.size(-1) == 2
|
||||
), "points are wrong : {:}".format(points.shape)
|
||||
(H, W), points = shape, points.clone()
|
||||
x = normalize_L(points[..., 0], W)
|
||||
y = normalize_L(points[..., 1], H)
|
||||
return torch.stack((x, y), dim=-1)
|
||||
|
||||
|
||||
# shape is a tuple [H, W]
|
||||
def denormalize_points(shape, points):
|
||||
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(
|
||||
shape
|
||||
) == 2, "invalid shape : {:}".format(shape)
|
||||
assert isinstance(points, torch.Tensor) and (
|
||||
points.shape[0] == 2
|
||||
), "points are wrong : {:}".format(points.shape)
|
||||
(H, W), points = shape, points.clone()
|
||||
points[0, :] = denormalize_L(points[0, :], W)
|
||||
points[1, :] = denormalize_L(points[1, :], H)
|
||||
return points
|
||||
|
||||
|
||||
# shape is a tuple [H, W]
|
||||
def denormalize_points_batch(shape, points):
|
||||
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(
|
||||
shape
|
||||
) == 2, "invalid shape : {:}".format(shape)
|
||||
assert isinstance(points, torch.Tensor) and (
|
||||
points.shape[-1] == 2
|
||||
), "points are wrong : {:}".format(points.shape)
|
||||
(H, W), points = shape, points.clone()
|
||||
x = denormalize_L(points[..., 0], W)
|
||||
y = denormalize_L(points[..., 1], H)
|
||||
return torch.stack((x, y), dim=-1)
|
||||
|
||||
|
||||
# make target * theta = source
|
||||
def solve2theta(source, target):
|
||||
source, target = source.clone(), target.clone()
|
||||
oks = source[2, :] == 1
|
||||
assert torch.sum(oks).item() >= 3, "valid points : {:} is short".format(oks)
|
||||
if target.size(0) == 2:
|
||||
target = torch.cat((target, oks.unsqueeze(0).float()), dim=0)
|
||||
source, target = source[:, oks], target[:, oks]
|
||||
source, target = source.transpose(1, 0), target.transpose(1, 0)
|
||||
assert source.size(1) == target.size(1) == 3
|
||||
# X, residual, rank, s = np.linalg.lstsq(target.numpy(), source.numpy())
|
||||
# theta = torch.Tensor(X.T[:2, :])
|
||||
X_, qr = torch.gels(source, target)
|
||||
theta = X_[:3, :2].transpose(1, 0)
|
||||
return theta
|
||||
|
||||
|
||||
# shape = [H,W]
|
||||
def affine2image(image, theta, shape):
|
||||
C, H, W = image.size()
|
||||
theta = theta[:2, :].unsqueeze(0)
|
||||
grid_size = torch.Size([1, C, shape[0], shape[1]])
|
||||
grid = F.affine_grid(theta, grid_size)
|
||||
affI = F.grid_sample(
|
||||
image.unsqueeze(0), grid, mode="bilinear", padding_mode="border"
|
||||
)
|
||||
return affI.squeeze(0)
|
||||
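A quick check that the point (de)normalization helpers are inverses of each other, using the [H, W] shape convention from the comments above (assuming they are re-exported from xautodl.utils as in the __init__ above):

import torch

from xautodl.utils import denormalize_points, normalize_points

H, W = 64, 128
points = torch.tensor([[10.0, 63.0, 127.0],   # x coordinates
                       [ 5.0, 31.0,  63.0]])  # y coordinates
normed = normalize_points((H, W), points)      # mapped into [-1, 1]
restored = denormalize_points((H, W), normed)
print(torch.allclose(points, restored))        # True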
17
AutoDL-Projects/xautodl/utils/evaluation_utils.py
Normal file
@@ -0,0 +1,17 @@
import torch


def obtain_accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape instead of view: `correct` can be non-contiguous on recent PyTorch
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
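Typical usage of obtain_accuracy on a batch of logits (a sketch; the inputs are random, so the accuracies hover around chance level):

import torch

from xautodl.utils import obtain_accuracy

logits = torch.randn(64, 10)              # a batch of 10-way predictions
targets = torch.randint(0, 10, (64,))
top1, top5 = obtain_accuracy(logits, targets, topk=(1, 5))
print(top1.item(), top5.item())           # percentages in [0, 100]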
227
AutoDL-Projects/xautodl/utils/flop_benchmark.py
Normal file
@@ -0,0 +1,227 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
|
||||
#####################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
def count_parameters_in_MB(model):
|
||||
return count_parameters(model, "mb", deprecated=True)
|
||||
|
||||
|
||||
def count_parameters(model_or_parameters, unit="mb", deprecated=False):
|
||||
if isinstance(model_or_parameters, nn.Module):
|
||||
counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters())
|
||||
elif isinstance(model_or_parameters, nn.Parameter):
|
||||
counts = model_or_parameters.numel()
|
||||
elif isinstance(model_or_parameters, (list, tuple)):
|
||||
counts = sum(
|
||||
count_parameters(x, None, deprecated) for x in model_or_parameters
|
||||
)
|
||||
else:
|
||||
counts = sum(np.prod(v.size()) for v in model_or_parameters)
|
||||
if not isinstance(unit, str) and unit is not None:
|
||||
raise ValueError("Unknow type of unit: {:}".format(unit))
|
||||
elif unit is None:
|
||||
counts = counts
|
||||
elif unit.lower() == "kb" or unit.lower() == "k":
|
||||
counts /= 1e3 if deprecated else 2 ** 10 # changed from 1e3 to 2^10
|
||||
elif unit.lower() == "mb" or unit.lower() == "m":
|
||||
counts /= 1e6 if deprecated else 2 ** 20 # changed from 1e6 to 2^20
|
||||
elif unit.lower() == "gb" or unit.lower() == "g":
|
||||
counts /= 1e9 if deprecated else 2 ** 30 # changed from 1e9 to 2^30
|
||||
else:
|
||||
raise ValueError("Unknow unit: {:}".format(unit))
|
||||
return counts
|
||||
|
||||
|
||||


def get_model_infos(model, shape):
    # model = copy.deepcopy( model )

    model = add_flops_counting_methods(model)
    # model = model.cuda()
    model.eval()

    # cache_inputs = torch.zeros(*shape).cuda()
    # cache_inputs = torch.zeros(*shape)
    cache_inputs = torch.rand(*shape)
    if next(model.parameters()).is_cuda:
        cache_inputs = cache_inputs.cuda()
    # print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log)
    with torch.no_grad():
        _____ = model(cache_inputs)
    FLOPs = compute_average_flops_cost(model) / 1e6
    Param = count_parameters_in_MB(model)

    if hasattr(model, "auxiliary_param"):
        aux_params = count_parameters_in_MB(model.auxiliary_param())
        print("The auxiliary params of this model are : {:}".format(aux_params))
        print(
            "We remove the auxiliary params from the total params ({:}) when counting".format(
                Param
            )
        )
        Param = Param - aux_params

    # print_log('FLOPs : {:} MB'.format(FLOPs), log)
    torch.cuda.empty_cache()
    model.apply(remove_hook_function)
    return FLOPs, Param


# ---- Public functions
def add_flops_counting_methods(model):
    model.__batch_counter__ = 0
    add_batch_counter_hook_function(model)
    model.apply(add_flops_counter_variable_or_reset)
    model.apply(add_flops_counter_hook_function)
    return model


def compute_average_flops_cost(model):
    """
    A method that will be available after add_flops_counting_methods() is called on a desired net object.
    Returns current mean flops consumption per image.
    """
    batches_count = model.__batch_counter__
    flops_sum = 0
    # or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \
    for module in model.modules():
        if (
            isinstance(module, torch.nn.Conv2d)
            or isinstance(module, torch.nn.Linear)
            or isinstance(module, torch.nn.Conv1d)
            or hasattr(module, "calculate_flop_self")
        ):
            flops_sum += module.__flops__
    return flops_sum / batches_count


# ---- Internal functions
def pool_flops_counter_hook(pool_module, inputs, output):
    batch_size = inputs[0].size(0)
    kernel_size = pool_module.kernel_size
    out_C, output_height, output_width = output.shape[1:]
    assert out_C == inputs[0].size(1), "{:} vs. {:}".format(out_C, inputs[0].size())

    overall_flops = (
        batch_size * out_C * output_height * output_width * kernel_size * kernel_size
    )
    pool_module.__flops__ += overall_flops


def self_calculate_flops_counter_hook(self_module, inputs, output):
    overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape)
    self_module.__flops__ += overall_flops


def fc_flops_counter_hook(fc_module, inputs, output):
    batch_size = inputs[0].size(0)
    xin, xout = fc_module.in_features, fc_module.out_features
    assert xin == inputs[0].size(1) and xout == output.size(1), "IO=({:}, {:})".format(
        xin, xout
    )
    overall_flops = batch_size * xin * xout
    if fc_module.bias is not None:
        overall_flops += batch_size * xout
    fc_module.__flops__ += overall_flops


def conv1d_flops_counter_hook(conv_module, inputs, outputs):
    batch_size = inputs[0].size(0)
    outL = outputs.shape[-1]
    [kernel] = conv_module.kernel_size
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups
    conv_per_position_flops = kernel * in_channels * out_channels / groups

    active_elements_count = batch_size * outL
    overall_flops = conv_per_position_flops * active_elements_count

    if conv_module.bias is not None:
        overall_flops += out_channels * active_elements_count
    conv_module.__flops__ += overall_flops


def conv2d_flops_counter_hook(conv_module, inputs, output):
    batch_size = inputs[0].size(0)
    output_height, output_width = output.shape[2:]

    kernel_height, kernel_width = conv_module.kernel_size
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups
    conv_per_position_flops = (
        kernel_height * kernel_width * in_channels * out_channels / groups
    )

    active_elements_count = batch_size * output_height * output_width
    overall_flops = conv_per_position_flops * active_elements_count

    if conv_module.bias is not None:
        overall_flops += out_channels * active_elements_count
    conv_module.__flops__ += overall_flops


def batch_counter_hook(module, inputs, output):
    # Can have multiple inputs, getting the first one
    inputs = inputs[0]
    batch_size = inputs.shape[0]
    module.__batch_counter__ += batch_size


def add_batch_counter_hook_function(module):
    if not hasattr(module, "__batch_counter_handle__"):
        handle = module.register_forward_hook(batch_counter_hook)
        module.__batch_counter_handle__ = handle


def add_flops_counter_variable_or_reset(module):
    if (
        isinstance(module, torch.nn.Conv2d)
        or isinstance(module, torch.nn.Linear)
        or isinstance(module, torch.nn.Conv1d)
        or isinstance(module, torch.nn.AvgPool2d)
        or isinstance(module, torch.nn.MaxPool2d)
        or hasattr(module, "calculate_flop_self")
    ):
        module.__flops__ = 0


def add_flops_counter_hook_function(module):
    if isinstance(module, torch.nn.Conv2d):
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(conv2d_flops_counter_hook)
            module.__flops_handle__ = handle
    elif isinstance(module, torch.nn.Conv1d):
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(conv1d_flops_counter_hook)
            module.__flops_handle__ = handle
    elif isinstance(module, torch.nn.Linear):
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(fc_flops_counter_hook)
            module.__flops_handle__ = handle
    elif isinstance(module, torch.nn.AvgPool2d) or isinstance(
        module, torch.nn.MaxPool2d
    ):
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(pool_flops_counter_hook)
            module.__flops_handle__ = handle
    elif hasattr(module, "calculate_flop_self"):  # self-defined module
        if not hasattr(module, "__flops_handle__"):
            handle = module.register_forward_hook(self_calculate_flops_counter_hook)
            module.__flops_handle__ = handle


def remove_hook_function(module):
    hookers = ["__batch_counter_handle__", "__flops_handle__"]
    for hooker in hookers:
        if hasattr(module, hooker):
            handle = getattr(module, hooker)
            handle.remove()
    keys = ["__flops__", "__batch_counter__"] + hookers
    for ckey in keys:
        if hasattr(module, ckey):
            delattr(module, ckey)
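A hedged usage sketch of the FLOP counter (assuming the package is importable as xautodl and that torchvision is installed, which this repo does not require):

import torchvision.models as models
from xautodl.utils.flop_benchmark import count_parameters, get_model_infos

net = models.resnet18()
flops, params = get_model_infos(net, shape=(1, 3, 224, 224))
print("FLOPs ~ {:.1f} M, Params ~ {:.2f}".format(flops, params))
# count_parameters uses a 2**20 divisor by default; count_parameters_in_MB keeps the legacy 1e6 divisor.
print("Params ~ {:.2f}".format(count_parameters(net, unit="mb")))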
86
AutoDL-Projects/xautodl/utils/gpu_manager.py
Normal file
@@ -0,0 +1,86 @@
import os


class GPUManager:
    queries = (
        "index",
        "gpu_name",
        "memory.free",
        "memory.used",
        "memory.total",
        "power.draw",
        "power.limit",
    )

    def __init__(self):
        all_gpus = self.query_gpu(False)

    def get_info(self, ctype):
        cmd = "nvidia-smi --query-gpu={} --format=csv,noheader".format(ctype)
        lines = os.popen(cmd).readlines()
        lines = [line.strip("\n") for line in lines]
        return lines

    def query_gpu(self, show=True):
        num_gpus = len(self.get_info("index"))
        all_gpus = [{} for i in range(num_gpus)]
        for query in self.queries:
            infos = self.get_info(query)
            for idx, info in enumerate(infos):
                all_gpus[idx][query] = info

        if "CUDA_VISIBLE_DEVICES" in os.environ:
            CUDA_VISIBLE_DEVICES = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
            selected_gpus = []
            for idx, CUDA_VISIBLE_DEVICE in enumerate(CUDA_VISIBLE_DEVICES):
                find = False
                for gpu in all_gpus:
                    if gpu["index"] == CUDA_VISIBLE_DEVICE:
                        assert not find, "Duplicate cuda device index : {}".format(
                            CUDA_VISIBLE_DEVICE
                        )
                        find = True
                        selected_gpus.append(gpu.copy())
                        selected_gpus[-1]["index"] = "{}".format(idx)
                assert find, "Did not find the device : {}".format(CUDA_VISIBLE_DEVICE)
            all_gpus = selected_gpus

        if show:
            allstrings = ""
            for gpu in all_gpus:
                string = "| "
                for query in self.queries:
                    if query.find("memory") == 0:
                        xinfo = "{:>9}".format(gpu[query])
                    else:
                        xinfo = gpu[query]
                    string = string + query + " : " + xinfo + " | "
                allstrings = allstrings + string + "\n"
            return allstrings
        else:
            return all_gpus

    def select_by_memory(self, numbers=1):
        all_gpus = self.query_gpu(False)
        assert numbers <= len(all_gpus), "Require {} gpus but only have {}".format(
            numbers, len(all_gpus)
        )
        alls = []
        for idx, gpu in enumerate(all_gpus):
            free_memory = gpu["memory.free"]
            free_memory = free_memory.split(" ")[0]
            free_memory = int(free_memory)
            index = gpu["index"]
            alls.append((free_memory, index))
        alls.sort(reverse=True)
        alls = [int(alls[i][1]) for i in range(numbers)]
        return sorted(alls)


"""
if __name__ == '__main__':
    manager = GPUManager()
    manager.query_gpu(True)
    indexes = manager.select_by_memory(3)
    print (indexes)
"""
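Beyond the commented-out main block above, a hedged sketch of a typical workflow (requires nvidia-smi on PATH; assumes the package imports as xautodl):

import os
from xautodl.utils.gpu_manager import GPUManager

manager = GPUManager()
print(manager.query_gpu(show=True))              # human-readable table of visible GPUs
chosen = manager.select_by_memory(numbers=1)     # indices of the GPUs with the most free memory
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in chosen)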
17
AutoDL-Projects/xautodl/utils/hash_utils.py
Normal file
@@ -0,0 +1,17 @@
import os
import hashlib


def get_md5_file(file_path, post_truncated=5):
    md5_hash = hashlib.md5()
    if os.path.exists(file_path):
        with open(file_path, "rb") as xfile:
            content = xfile.read()
        md5_hash.update(content)
        digest = md5_hash.hexdigest()
    else:
        raise ValueError("[get_md5_file] {:} does not exist".format(file_path))
    if post_truncated is None:
        return digest
    else:
        return digest[-post_truncated:]
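A small usage sketch (a temporary file is created only for the demonstration):

import tempfile
from xautodl.utils.hash_utils import get_md5_file

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
    tmp.write("hello autodl")
    path = tmp.name
print(get_md5_file(path))        # last 5 hex characters of the MD5 digest
print(get_md5_file(path, None))  # full 32-character digest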
76
AutoDL-Projects/xautodl/utils/nas_utils.py
Normal file
@@ -0,0 +1,76 @@
# This file is for experimental usage
import torch, random
import numpy as np
from copy import deepcopy
import torch.nn as nn

# modules in AutoDL
from models import CellStructure
from log_utils import time_string


def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
    print(
        "This is an old version of the code for using the NAS-Bench API; it should be modified to align with the new version. Please contact me for more details if you use this function."
    )
    weights = deepcopy(model.state_dict())
    model.train(cal_mode)
    with torch.no_grad():
        logits = nn.functional.log_softmax(model.arch_parameters, dim=-1)
        archs = CellStructure.gen_all(model.op_names, model.max_nodes, False)
        probs, accuracies, gt_accs_10_valid, gt_accs_10_test = [], [], [], []
        loader_iter = iter(xloader)
        random.seed(seed)
        random.shuffle(archs)
        for idx, arch in enumerate(archs):
            arch_index = api.query_index_by_arch(arch)
            metrics = api.get_more_info(arch_index, "cifar10-valid", None, False, False)
            gt_accs_10_valid.append(metrics["valid-accuracy"])
            metrics = api.get_more_info(arch_index, "cifar10", None, False, False)
            gt_accs_10_test.append(metrics["test-accuracy"])
            select_logits = []
            for i, node_info in enumerate(arch.nodes):
                for op, xin in node_info:
                    node_str = "{:}<-{:}".format(i + 1, xin)
                    op_index = model.op_names.index(op)
                    select_logits.append(logits[model.edge2index[node_str], op_index])
            cur_prob = sum(select_logits).item()
            probs.append(cur_prob)
        cor_prob_valid = np.corrcoef(probs, gt_accs_10_valid)[0, 1]
        cor_prob_test = np.corrcoef(probs, gt_accs_10_test)[0, 1]
        print(
            "{:} correlation for probabilities : {:.6f} on CIFAR-10 validation and {:.6f} on CIFAR-10 test".format(
                time_string(), cor_prob_valid, cor_prob_test
            )
        )

        for idx, arch in enumerate(archs):
            model.set_cal_mode("dynamic", arch)
            try:
                inputs, targets = next(loader_iter)
            except StopIteration:
                loader_iter = iter(xloader)
                inputs, targets = next(loader_iter)
            _, logits = model(inputs.cuda())
            _, preds = torch.max(logits, dim=-1)
            correct = (preds == targets.cuda()).float()
            accuracies.append(correct.mean().item())
            if idx != 0 and (idx % 500 == 0 or idx + 1 == len(archs)):
                cor_accs_valid = np.corrcoef(accuracies, gt_accs_10_valid[: idx + 1])[
                    0, 1
                ]
                cor_accs_test = np.corrcoef(accuracies, gt_accs_10_test[: idx + 1])[
                    0, 1
                ]
                print(
                    "{:} {:05d}/{:05d} mode={:5s}, correlation : accs={:.5f} for CIFAR-10 valid, {:.5f} for CIFAR-10 test.".format(
                        time_string(),
                        idx,
                        len(archs),
                        "Train" if cal_mode else "Eval",
                        cor_accs_valid,
                        cor_accs_test,
                    )
                )
    model.load_state_dict(weights)
    return archs, probs, accuracies
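The core measurement in evaluate_one_shot is a Pearson correlation between supernet-derived scores and benchmark ground truths; a self-contained illustration of that metric with made-up numbers (not results produced by this code):

import numpy as np

scores = [0.10, 0.35, 0.20, 0.90, 0.55]    # e.g. summed log-probabilities per architecture
gt_accs = [88.1, 91.3, 89.0, 94.2, 92.7]   # e.g. benchmark ground-truth accuracies
print("correlation: {:.4f}".format(np.corrcoef(scores, gt_accs)[0, 1]))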
129
AutoDL-Projects/xautodl/utils/qlib_utils.py
Normal file
@@ -0,0 +1,129 @@
import os
import numpy as np
from typing import List, Text
from collections import defaultdict, OrderedDict


class QResult:
    """A class to maintain the results of a qlib experiment."""

    def __init__(self, name):
        self._result = defaultdict(list)
        self._name = name
        self._recorder_paths = []
        self._date2ICs = []

    def append(self, key, value):
        self._result[key].append(value)

    def append_path(self, xpath):
        self._recorder_paths.append(xpath)

    def append_date2ICs(self, date2IC):
        if self._date2ICs:  # not empty
            keys = sorted(list(date2IC.keys()))
            pre_keys = sorted(list(self._date2ICs[0].keys()))
            assert len(keys) == len(pre_keys)
            for i, (x, y) in enumerate(zip(keys, pre_keys)):
                assert x == y, "[{:}] {:} vs {:}".format(i, x, y)
        self._date2ICs.append(date2IC)

    def find_all_dates(self):
        dates = self._date2ICs[-1].keys()
        return sorted(list(dates))

    def get_IC_by_date(self, date, scale=1.0):
        values = []
        for date2IC in self._date2ICs:
            values.append(date2IC[date] * scale)
        return float(np.mean(values)), float(np.std(values))

    @property
    def name(self):
        return self._name

    @property
    def paths(self):
        return self._recorder_paths

    @property
    def result(self):
        return self._result

    @property
    def keys(self):
        return list(self._result.keys())

    def __len__(self):
        return len(self._result)

    def __repr__(self):
        return "{name}({xname}, {num} metrics)".format(
            name=self.__class__.__name__, xname=self.name, num=len(self.result)
        )

    def __getitem__(self, key):
        if key not in self._result:
            raise ValueError(
                "Invalid key {:}, please use one of {:}".format(key, self.keys)
            )
        values = self._result[key]
        return float(np.mean(values))

    def update(self, metrics, filter_keys=None):
        for key, value in metrics.items():
            if filter_keys is not None and key in filter_keys:
                key = filter_keys[key]
            elif filter_keys is not None:
                continue
            self.append(key, value)

    @staticmethod
    def full_str(xstr, space):
        xformat = "{:" + str(space) + "s}"
        return xformat.format(str(xstr))

    @staticmethod
    def merge_dict(dict_list):
        new_dict = dict()
        for xkey in dict_list[0].keys():
            values = [x for xdict in dict_list for x in xdict[xkey]]
            new_dict[xkey] = values
        return new_dict

    def info(
        self,
        keys: List[Text],
        separate: Text = "& ",
        space: int = 20,
        verbose: bool = True,
        version: str = "v1",
    ):
        available_keys = []
        for key in keys:
            if key not in self.result:
                print("There is an invalid key [{:}].".format(key))
            else:
                available_keys.append(key)
        head_str = separate.join([self.full_str(x, space) for x in available_keys])
        values = []
        for key in available_keys:
            if "IR" in key:
                current_values = [x * 100 for x in self._result[key]]
            else:
                current_values = self._result[key]
            mean = np.mean(current_values)
            std = np.std(current_values)
            if version == "v0":
                values.append("{:.2f} $\\pm$ {:.2f}".format(mean, std))
            elif version == "v1":
                values.append(
                    "{:.2f}".format(mean) + " \\subs{" + "{:.2f}".format(std) + "}"
                )
            else:
                raise ValueError("Unknown version")
        value_str = separate.join([self.full_str(x, space) for x in values])
        if verbose:
            print(head_str)
            print(value_str)
        return head_str, value_str
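A hedged usage sketch of QResult bookkeeping (the metric names and numbers below are invented for illustration; assumes the package imports as xautodl):

from xautodl.utils.qlib_utils import QResult

result = QResult("my-exp")
for run in range(3):                               # e.g. three repeated qlib runs
    result.update({"IC": 0.040 + 0.005 * run, "ICIR": 0.30 + 0.02 * run})
print(result)                                      # QResult(my-exp, 2 metrics)
print(result["IC"])                                # mean IC over the runs
head_str, value_str = result.info(["IC", "ICIR"])  # also prints an aligned table row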
Some files were not shown because too many files have changed in this diff.