Prototype MAML
		| @@ -23,6 +23,9 @@ from datasets.synthetic_core import get_synthetic_env | |||||||
| from models.xcore import get_model | from models.xcore import get_model | ||||||
|  |  | ||||||
|  |  | ||||||
|  | from lfna_utils import lfna_setup | ||||||
|  |  | ||||||
|  |  | ||||||
| def subsample(historical_x, historical_y, maxn=10000): | def subsample(historical_x, historical_y, maxn=10000): | ||||||
|     total = historical_x.size(0) |     total = historical_x.size(0) | ||||||
|     if total <= maxn: |     if total <= maxn: | ||||||
| @@ -33,24 +36,7 @@ def subsample(historical_x, historical_y, maxn=10000): | |||||||
|  |  | ||||||
|  |  | ||||||
| def main(args): | def main(args): | ||||||
|     prepare_seed(args.rand_seed) |     logger, env_info = lfna_setup(args) | ||||||
|     logger = prepare_logger(args) |  | ||||||
|  |  | ||||||
|     cache_path = ( |  | ||||||
|         logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version) |  | ||||||
|     ).resolve() |  | ||||||
|     if cache_path.exists(): |  | ||||||
|         env_info = torch.load(cache_path) |  | ||||||
|     else: |  | ||||||
|         env_info = dict() |  | ||||||
|         dynamic_env = get_synthetic_env(version=args.env_version) |  | ||||||
|         env_info["total"] = len(dynamic_env) |  | ||||||
|         for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): |  | ||||||
|             env_info["{:}-timestamp".format(idx)] = timestamp |  | ||||||
|             env_info["{:}-x".format(idx)] = _allx |  | ||||||
|             env_info["{:}-y".format(idx)] = _ally |  | ||||||
|         env_info["dynamic_env"] = dynamic_env |  | ||||||
|         torch.save(env_info, cache_path) |  | ||||||
|  |  | ||||||
|     # check indexes to be evaluated |     # check indexes to be evaluated | ||||||
|     to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None) |     to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None) | ||||||
| @@ -60,6 +46,8 @@ def main(args): | |||||||
|         ) |         ) | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |     w_container_per_epoch = dict() | ||||||
|  |  | ||||||
|     per_timestamp_time, start_time = AverageMeter(), time.time() |     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||||
|     for i, idx in enumerate(to_evaluate_indexes): |     for i, idx in enumerate(to_evaluate_indexes): | ||||||
|  |  | ||||||
| @@ -89,9 +77,6 @@ def main(args): | |||||||
|             output_dim=1, |             output_dim=1, | ||||||
|             act_cls="leaky_relu", |             act_cls="leaky_relu", | ||||||
|             norm_cls="identity", |             norm_cls="identity", | ||||||
|             # norm_cls="simple_norm", |  | ||||||
|             # mean=mean, |  | ||||||
|             # std=std, |  | ||||||
|         ) |         ) | ||||||
|         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) |         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||||
|         # build optimizer |         # build optimizer | ||||||
| @@ -144,6 +129,7 @@ def main(args): | |||||||
|         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( |         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( | ||||||
|             idx, env_info["total"] |             idx, env_info["total"] | ||||||
|         ) |         ) | ||||||
|  |         w_container_per_epoch[idx] = model.get_w_container().no_grad_clone() | ||||||
|         save_checkpoint( |         save_checkpoint( | ||||||
|             { |             { | ||||||
|                 "model_state_dict": model.state_dict(), |                 "model_state_dict": model.state_dict(), | ||||||
| @@ -155,10 +141,14 @@ def main(args): | |||||||
|             logger, |             logger, | ||||||
|         ) |         ) | ||||||
|         logger.log("") |         logger.log("") | ||||||
|  |  | ||||||
|         per_timestamp_time.update(time.time() - start_time) |         per_timestamp_time.update(time.time() - start_time) | ||||||
|         start_time = time.time() |         start_time = time.time() | ||||||
|  |  | ||||||
|  |     save_checkpoint( | ||||||
|  |         {"w_container_per_epoch": w_container_per_epoch}, | ||||||
|  |         logger.path(None) / "final-ckp.pth", | ||||||
|  |         logger, | ||||||
|  |     ) | ||||||
|     logger.log("-" * 200 + "\n") |     logger.log("-" * 200 + "\n") | ||||||
|     logger.close() |     logger.close() | ||||||
|  |  | ||||||
| @@ -210,5 +200,7 @@ if __name__ == "__main__": | |||||||
|     if args.rand_seed is None or args.rand_seed < 0: |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|         args.rand_seed = random.randint(1, 100000) |         args.rand_seed = random.randint(1, 100000) | ||||||
|     assert args.save_dir is not None, "The save dir argument can not be None" |     assert args.save_dir is not None, "The save dir argument can not be None" | ||||||
|     args.save_dir = "{:}-{:}".format(args.save_dir, args.env_version) |     args.save_dir = "{:}-{:}-d{:}".format( | ||||||
|  |         args.save_dir, args.env_version, args.hidden_dim | ||||||
|  |     ) | ||||||
|     main(args) |     main(args) | ||||||
|   | |||||||
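The environment-caching block removed above now lives in `lfna_utils.lfna_setup` (added later in this commit). For reference, a minimal sketch of how a caller consumes its return values; the loop body is illustrative, but the `env_info` keys follow the cache layout built in `exps/LFNA/lfna_utils.py`:

```python
# Minimal sketch of consuming lfna_setup's return values.
# `args` is the script's parsed argparse namespace.
from lfna_utils import lfna_setup

logger, env_info, model_kwargs = lfna_setup(args)

total = env_info["total"]  # number of timestamps in the synthetic environment
for idx in range(total):
    timestamp = env_info["{:}-timestamp".format(idx)]
    xs = env_info["{:}-x".format(idx)]  # inputs observed at this timestamp
    ys = env_info["{:}-y".format(idx)]  # targets observed at this timestamp
```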
							
								
								
									
exps/LFNA/basic-maml.py (new file, 220 lines)
							| @@ -0,0 +1,220 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
|  | ##################################################### | ||||||
|  | # python exps/LFNA/basic-maml.py --env_version v1   # | ||||||
|  | # python exps/LFNA/basic-maml.py --env_version v2   # | ||||||
|  | ##################################################### | ||||||
|  | import sys, time, copy, torch, random, argparse | ||||||
|  | from tqdm import tqdm | ||||||
|  | from copy import deepcopy | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||||
|  | if str(lib_dir) not in sys.path: | ||||||
|  |     sys.path.insert(0, str(lib_dir)) | ||||||
|  | from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||||
|  | from log_utils import time_string | ||||||
|  | from log_utils import AverageMeter, convert_secs2time | ||||||
|  |  | ||||||
|  | from utils import split_str2indexes | ||||||
|  |  | ||||||
|  | from procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||||
|  | from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||||
|  | from datasets.synthetic_core import get_synthetic_env | ||||||
|  | from models.xcore import get_model | ||||||
|  | from xlayers import super_core | ||||||
|  |  | ||||||
|  | from lfna_utils import lfna_setup, TimeData | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class MAML: | ||||||
|  |     """A LFNA meta-model that uses the MLP as delta-net.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, container, criterion, meta_lr, inner_lr=0.01, inner_step=1): | ||||||
|  |         self.criterion = criterion | ||||||
|  |         self.container = container | ||||||
|  |         self.meta_optimizer = torch.optim.Adam( | ||||||
|  |             self.container.parameters(), lr=meta_lr, amsgrad=True | ||||||
|  |         ) | ||||||
|  |         self.inner_lr = inner_lr | ||||||
|  |         self.inner_step = inner_step | ||||||
|  |  | ||||||
|  |     def adapt(self, model, dataset): | ||||||
|  |         # Take `inner_step` gradient steps on this timestamp's data and | ||||||
|  |         # return the adapted (fast) weight container. | ||||||
|  |         fast_container = self.container | ||||||
|  |         for _ in range(self.inner_step): | ||||||
|  |             y_hat = model.forward_with_container(dataset.x, fast_container) | ||||||
|  |             loss = self.criterion(y_hat, dataset.y) | ||||||
|  |             grads = torch.autograd.grad(loss, fast_container.parameters()) | ||||||
|  |             # Note: pass create_graph=True above for the full second-order | ||||||
|  |             # MAML gradient; as written this is a first-order approximation. | ||||||
|  |             fast_container = fast_container.additive( | ||||||
|  |                 [-self.inner_lr * grad for grad in grads] | ||||||
|  |             ) | ||||||
|  |         return fast_container | ||||||
|  |  | ||||||
|  |     def step(self): | ||||||
|  |         torch.nn.utils.clip_grad_norm_(self.container.parameters(), 1.0) | ||||||
|  |         self.meta_optimizer.step() | ||||||
|  |  | ||||||
|  |     def zero_grad(self): | ||||||
|  |         self.meta_optimizer.zero_grad() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def main(args): | ||||||
|  |     logger, env_info = lfna_setup(args) | ||||||
|  |  | ||||||
|  |     total_time = env_info["total"] | ||||||
|  |     for i in range(total_time): | ||||||
|  |         for xkey in ("timestamp", "x", "y"): | ||||||
|  |             nkey = "{:}-{:}".format(i, xkey) | ||||||
|  |             assert nkey in env_info, "{:} not in {:}".format(nkey, list(env_info.keys())) | ||||||
|  |     train_time_bar = total_time // 2 | ||||||
|  |     base_model = get_model( | ||||||
|  |         dict(model_type="simple_mlp"), | ||||||
|  |         act_cls="leaky_relu", | ||||||
|  |         norm_cls="identity", | ||||||
|  |         input_dim=1, | ||||||
|  |         output_dim=1, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     w_container = base_model.get_w_container() | ||||||
|  |     criterion = torch.nn.MSELoss() | ||||||
|  |     print("There are {:} weights.".format(w_container.numel())) | ||||||
|  |  | ||||||
|  |     maml = MAML(w_container, criterion, args.meta_lr, args.inner_lr, args.inner_step) | ||||||
|  |  | ||||||
|  |     # meta-training | ||||||
|  |     per_epoch_time, start_time = AverageMeter(), time.time() | ||||||
|  |     for iepoch in range(args.epochs): | ||||||
|  |  | ||||||
|  |         need_time = "Time Left: {:}".format( | ||||||
|  |             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) | ||||||
|  |         ) | ||||||
|  |         logger.log( | ||||||
|  |             "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs) | ||||||
|  |             + need_time | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         maml.zero_grad() | ||||||
|  |  | ||||||
|  |         all_meta_losses = [] | ||||||
|  |         for ibatch in range(args.meta_batch): | ||||||
|  |             sampled_timestamp = random.randint(0, train_time_bar) | ||||||
|  |             past_dataset = TimeData( | ||||||
|  |                 sampled_timestamp, | ||||||
|  |                 env_info["{:}-x".format(sampled_timestamp)], | ||||||
|  |                 env_info["{:}-y".format(sampled_timestamp)], | ||||||
|  |             ) | ||||||
|  |             future_dataset = TimeData( | ||||||
|  |                 sampled_timestamp + 1, | ||||||
|  |                 env_info["{:}-x".format(sampled_timestamp + 1)], | ||||||
|  |                 env_info["{:}-y".format(sampled_timestamp + 1)], | ||||||
|  |             ) | ||||||
|  |             # adapt on the past timestamp, then evaluate the adapted | ||||||
|  |             # weights on the next timestamp as the meta-objective | ||||||
|  |             fast_container = maml.adapt(base_model, past_dataset) | ||||||
|  |             future_y_hat = base_model.forward_with_container( | ||||||
|  |                 future_dataset.x, fast_container | ||||||
|  |             ) | ||||||
|  |             future_loss = criterion(future_y_hat, future_dataset.y) | ||||||
|  |             all_meta_losses.append(future_loss) | ||||||
|  |  | ||||||
|  |         meta_loss = torch.stack(all_meta_losses).mean() | ||||||
|  |         meta_loss.backward() | ||||||
|  |         maml.step() | ||||||
|  |  | ||||||
|  |         logger.log("meta-loss: {:.4f}".format(meta_loss.item())) | ||||||
|  |  | ||||||
|  |         per_epoch_time.update(time.time() - start_time) | ||||||
|  |         start_time = time.time() | ||||||
|  |  | ||||||
|  |     logger.log("-" * 200 + "\n") | ||||||
|  |     logger.close() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     parser = argparse.ArgumentParser("Use the data in the past.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--save_dir", | ||||||
|  |         type=str, | ||||||
|  |         default="./outputs/lfna-synthetic/maml", | ||||||
|  |         help="The checkpoint directory.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--env_version", | ||||||
|  |         type=str, | ||||||
|  |         required=True, | ||||||
|  |         help="The synthetic enviornment version.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--meta_lr", | ||||||
|  |         type=float, | ||||||
|  |         default=0.01, | ||||||
|  |         help="The learning rate for the MAML optimizer (default is Adam)", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--inner_lr", | ||||||
|  |         type=float, | ||||||
|  |         default=0.01, | ||||||
|  |         help="The learning rate for the inner optimization", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--inner_step", type=int, default=1, help="The inner loop steps for MAML." | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--meta_batch", | ||||||
|  |         type=int, | ||||||
|  |         default=5, | ||||||
|  |         help="The batch size for the meta-model", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--epochs", | ||||||
|  |         type=int, | ||||||
|  |         default=1000, | ||||||
|  |         help="The total number of epochs.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--workers", | ||||||
|  |         type=int, | ||||||
|  |         default=4, | ||||||
|  |         help="The number of data loading workers (default: 4)", | ||||||
|  |     ) | ||||||
|  |     # Random Seed | ||||||
|  |     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||||
|  |     args = parser.parse_args() | ||||||
|  |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|  |         args.rand_seed = random.randint(1, 100000) | ||||||
|  |     assert args.save_dir is not None, "The save dir argument can not be None" | ||||||
|  |     main(args) | ||||||
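For orientation, below is a self-contained sketch of the past-to-future MAML update that basic-maml.py is building toward, written with plain tensors and `torch.nn.functional` instead of the repository's weight-container API. The layer shapes, the `tasks` structure, and the helper names are illustrative assumptions, not interfaces from this codebase:

```python
import torch
import torch.nn.functional as F


def forward(weights, x):
    # two-layer MLP analogous to the "simple_mlp" above: 1 -> hidden -> 1
    w1, b1, w2, b2 = weights
    return F.linear(F.leaky_relu(F.linear(x, w1, b1)), w2, b2)


def maml_meta_step(weights, meta_opt, tasks, inner_lr=0.01, inner_step=1):
    # one meta-update over a batch of (past_x, past_y, future_x, future_y) tasks,
    # mirroring the past/future timestamp pairs sampled in basic-maml.py
    meta_opt.zero_grad()
    meta_losses = []
    for past_x, past_y, future_x, future_y in tasks:
        fast = list(weights)
        for _ in range(inner_step):
            loss = F.mse_loss(forward(fast, past_x), past_y)
            grads = torch.autograd.grad(loss, fast, create_graph=True)
            fast = [w - inner_lr * g for w, g in zip(fast, grads)]
        # the adapted weights should fit the *next* timestamp
        meta_losses.append(F.mse_loss(forward(fast, future_x), future_y))
    meta_loss = torch.stack(meta_losses).mean()
    meta_loss.backward()  # second-order gradients flow back into `weights`
    meta_opt.step()
    return meta_loss.item()


# usage on a toy 1-D regression stream (data is synthetic for illustration)
hidden = 16
weights = [
    torch.randn(hidden, 1).mul_(0.1).requires_grad_(),  # w1
    torch.zeros(hidden).requires_grad_(),                # b1
    torch.randn(1, hidden).mul_(0.1).requires_grad_(),   # w2
    torch.zeros(1).requires_grad_(),                     # b2
]
meta_opt = torch.optim.Adam(weights, lr=0.01, amsgrad=True)
xs = torch.linspace(-1, 1, 64).view(-1, 1)
tasks = [(xs, torch.sin(3 * xs), xs, torch.sin(3 * xs + 0.1)) for _ in range(5)]
print(maml_meta_step(weights, meta_opt, tasks))
```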
| @@ -1,7 +1,8 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
| ##################################################### | ##################################################### | ||||||
| # python exps/LFNA/basic-same.py --srange 1-999 | # python exps/LFNA/basic-same.py --srange 1-999 --env_version v1 --hidden_dim 16 | ||||||
|  | # python exps/LFNA/basic-same.py --srange 1-999 --env_version v2 --hidden_dim | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, time, copy, torch, random, argparse | import sys, time, copy, torch, random, argparse | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| @@ -22,6 +23,8 @@ from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | |||||||
| from datasets.synthetic_core import get_synthetic_env | from datasets.synthetic_core import get_synthetic_env | ||||||
| from models.xcore import get_model | from models.xcore import get_model | ||||||
|  |  | ||||||
|  | from lfna_utils import lfna_setup | ||||||
|  |  | ||||||
|  |  | ||||||
| def subsample(historical_x, historical_y, maxn=10000): | def subsample(historical_x, historical_y, maxn=10000): | ||||||
|     total = historical_x.size(0) |     total = historical_x.size(0) | ||||||
| @@ -33,22 +36,7 @@ def subsample(historical_x, historical_y, maxn=10000): | |||||||
|  |  | ||||||
|  |  | ||||||
| def main(args): | def main(args): | ||||||
|     prepare_seed(args.rand_seed) |     logger, env_info, model_kwargs = lfna_setup(args) | ||||||
|     logger = prepare_logger(args) |  | ||||||
|  |  | ||||||
|     cache_path = (logger.path(None) / ".." / "env-info.pth").resolve() |  | ||||||
|     if cache_path.exists(): |  | ||||||
|         env_info = torch.load(cache_path) |  | ||||||
|     else: |  | ||||||
|         env_info = dict() |  | ||||||
|         dynamic_env = get_synthetic_env() |  | ||||||
|         env_info["total"] = len(dynamic_env) |  | ||||||
|         for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): |  | ||||||
|             env_info["{:}-timestamp".format(idx)] = timestamp |  | ||||||
|             env_info["{:}-x".format(idx)] = _allx |  | ||||||
|             env_info["{:}-y".format(idx)] = _ally |  | ||||||
|         env_info["dynamic_env"] = dynamic_env |  | ||||||
|         torch.save(env_info, cache_path) |  | ||||||
|  |  | ||||||
|     # check indexes to be evaluated |     # check indexes to be evaluated | ||||||
|     to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None) |     to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None) | ||||||
| @@ -78,16 +66,6 @@ def main(args): | |||||||
|         historical_x = env_info["{:}-x".format(idx)] |         historical_x = env_info["{:}-x".format(idx)] | ||||||
|         historical_y = env_info["{:}-y".format(idx)] |         historical_y = env_info["{:}-y".format(idx)] | ||||||
|         # build model |         # build model | ||||||
|         mean, std = historical_x.mean().item(), historical_x.std().item() |  | ||||||
|         model_kwargs = dict( |  | ||||||
|             input_dim=1, |  | ||||||
|             output_dim=1, |  | ||||||
|             act_cls="leaky_relu", |  | ||||||
|             norm_cls="identity", |  | ||||||
|             # norm_cls="simple_norm", |  | ||||||
|             # mean=mean, |  | ||||||
|             # std=std, |  | ||||||
|         ) |  | ||||||
|         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) |         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||||
|         # build optimizer |         # build optimizer | ||||||
|         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) |         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||||
| @@ -151,9 +129,9 @@ def main(args): | |||||||
|             logger, |             logger, | ||||||
|         ) |         ) | ||||||
|         logger.log("") |         logger.log("") | ||||||
|  |  | ||||||
|         per_timestamp_time.update(time.time() - start_time) |         per_timestamp_time.update(time.time() - start_time) | ||||||
|         start_time = time.time() |         start_time = time.time() | ||||||
|  |  | ||||||
|     save_checkpoint( |     save_checkpoint( | ||||||
|         {"w_container_per_epoch": w_container_per_epoch}, |         {"w_container_per_epoch": w_container_per_epoch}, | ||||||
|         logger.path(None) / "final-ckp.pth", |         logger.path(None) / "final-ckp.pth", | ||||||
| @@ -172,6 +150,18 @@ if __name__ == "__main__": | |||||||
|         default="./outputs/lfna-synthetic/use-same-timestamp", |         default="./outputs/lfna-synthetic/use-same-timestamp", | ||||||
|         help="The checkpoint directory.", |         help="The checkpoint directory.", | ||||||
|     ) |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--env_version", | ||||||
|  |         type=str, | ||||||
|  |         required=True, | ||||||
|  |         help="The synthetic enviornment version.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--hidden_dim", | ||||||
|  |         type=int, | ||||||
|  |         required=True, | ||||||
|  |         help="The hidden dimension.", | ||||||
|  |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--init_lr", |         "--init_lr", | ||||||
|         type=float, |         type=float, | ||||||
| @@ -205,4 +195,7 @@ if __name__ == "__main__": | |||||||
|     if args.rand_seed is None or args.rand_seed < 0: |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|         args.rand_seed = random.randint(1, 100000) |         args.rand_seed = random.randint(1, 100000) | ||||||
|     assert args.save_dir is not None, "The save dir argument can not be None" |     assert args.save_dir is not None, "The save dir argument can not be None" | ||||||
|  |     args.save_dir = "{:}-{:}-d{:}".format( | ||||||
|  |         args.save_dir, args.env_version, args.hidden_dim | ||||||
|  |     ) | ||||||
|     main(args) |     main(args) | ||||||
|   | |||||||
							
								
								
									
exps/LFNA/lfna-v0.py (new file, 272 lines)
							| @@ -0,0 +1,272 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
|  | ##################################################### | ||||||
|  | # python exps/LFNA/lfna-v0.py | ||||||
|  | ##################################################### | ||||||
|  | import sys, time, copy, torch, random, argparse | ||||||
|  | from tqdm import tqdm | ||||||
|  | from copy import deepcopy | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||||
|  | if str(lib_dir) not in sys.path: | ||||||
|  |     sys.path.insert(0, str(lib_dir)) | ||||||
|  | from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||||
|  | from log_utils import time_string | ||||||
|  | from log_utils import AverageMeter, convert_secs2time | ||||||
|  |  | ||||||
|  | from utils import split_str2indexes | ||||||
|  |  | ||||||
|  | from procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||||
|  | from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||||
|  | from datasets.synthetic_core import get_synthetic_env | ||||||
|  | from models.xcore import get_model | ||||||
|  | from xlayers import super_core | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class LFNAmlp: | ||||||
|  |     """A LFNA meta-model that uses the MLP as delta-net.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, obs_dim, hidden_sizes, act_name): | ||||||
|  |         self.delta_net = super_core.SuperSequential( | ||||||
|  |             super_core.SuperLinear(obs_dim, hidden_sizes[0]), | ||||||
|  |             super_core.super_name2activation[act_name](), | ||||||
|  |             super_core.SuperLinear(hidden_sizes[0], hidden_sizes[1]), | ||||||
|  |             super_core.super_name2activation[act_name](), | ||||||
|  |             super_core.SuperLinear(hidden_sizes[1], 1), | ||||||
|  |         ) | ||||||
|  |         self.meta_optimizer = torch.optim.Adam( | ||||||
|  |             self.delta_net.parameters(), lr=0.01, amsgrad=True | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def adapt(self, model, criterion, w_container, seq_datasets): | ||||||
|  |         w_container.requires_grad_(True) | ||||||
|  |         containers = [w_container] | ||||||
|  |         for idx, dataset in enumerate(seq_datasets): | ||||||
|  |             x, y = dataset.x, dataset.y | ||||||
|  |             y_hat = model.forward_with_container(x, containers[-1]) | ||||||
|  |             loss = criterion(y_hat, y) | ||||||
|  |             gradients = torch.autograd.grad(loss, containers[-1].tensors) | ||||||
|  |             with torch.no_grad(): | ||||||
|  |                 flatten_w = containers[-1].flatten().view(-1, 1) | ||||||
|  |                 flatten_g = containers[-1].flatten(gradients).view(-1, 1) | ||||||
|  |                 input_statistics = torch.tensor([x.mean(), x.std()]).view(1, 2) | ||||||
|  |                 input_statistics = input_statistics.expand(flatten_w.numel(), -1) | ||||||
|  |             delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1) | ||||||
|  |             delta = self.delta_net(delta_inputs).view(-1) | ||||||
|  |             delta = torch.clamp(delta, -0.5, 0.5) | ||||||
|  |             unflatten_delta = containers[-1].unflatten(delta) | ||||||
|  |             future_container = containers[-1].no_grad_clone().additive(unflatten_delta) | ||||||
|  |             # future_container = containers[-1].additive(unflatten_delta) | ||||||
|  |             containers.append(future_container) | ||||||
|  |         # containers = containers[1:] | ||||||
|  |         meta_loss = [] | ||||||
|  |         temp_containers = [] | ||||||
|  |         for idx, dataset in enumerate(seq_datasets): | ||||||
|  |             if idx == 0: | ||||||
|  |                 continue | ||||||
|  |             current_container = containers[idx] | ||||||
|  |             y_hat = model.forward_with_container(dataset.x, current_container) | ||||||
|  |             loss = criterion(y_hat, dataset.y) | ||||||
|  |             meta_loss.append(loss) | ||||||
|  |             temp_containers.append((dataset.timestamp, current_container, -loss.item())) | ||||||
|  |         meta_loss = sum(meta_loss) | ||||||
|  |         w_container.requires_grad_(False) | ||||||
|  |         # meta_loss.backward() | ||||||
|  |         # self.meta_optimizer.step() | ||||||
|  |         return meta_loss, temp_containers | ||||||
|  |  | ||||||
|  |     def step(self): | ||||||
|  |         torch.nn.utils.clip_grad_norm_(self.delta_net.parameters(), 1.0) | ||||||
|  |         self.meta_optimizer.step() | ||||||
|  |  | ||||||
|  |     def zero_grad(self): | ||||||
|  |         self.meta_optimizer.zero_grad() | ||||||
|  |         self.delta_net.zero_grad() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TimeData: | ||||||
|  |     def __init__(self, timestamp, xs, ys): | ||||||
|  |         self._timestamp = timestamp | ||||||
|  |         self._xs = xs | ||||||
|  |         self._ys = ys | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def x(self): | ||||||
|  |         return self._xs | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def y(self): | ||||||
|  |         return self._ys | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def timestamp(self): | ||||||
|  |         return self._timestamp | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Population: | ||||||
|  |     """A population used to maintain models at different timestamps.""" | ||||||
|  |  | ||||||
|  |     def __init__(self): | ||||||
|  |         self._time2model = dict() | ||||||
|  |         self._time2score = dict()  # higher is better | ||||||
|  |  | ||||||
|  |     def append(self, timestamp, model, score): | ||||||
|  |         if timestamp in self._time2model: | ||||||
|  |             if self._time2score[timestamp] > score: | ||||||
|  |                 return | ||||||
|  |         self._time2model[timestamp] = model.no_grad_clone() | ||||||
|  |         self._time2score[timestamp] = score | ||||||
|  |  | ||||||
|  |     def query(self, timestamp): | ||||||
|  |         closest_timestamp = None | ||||||
|  |         for xtime, model in self._time2model.items(): | ||||||
|  |             if closest_timestamp is None or ( | ||||||
|  |                 xtime < timestamp and timestamp - closest_timestamp >= timestamp - xtime | ||||||
|  |             ): | ||||||
|  |                 closest_timestamp = xtime | ||||||
|  |         return self._time2model[closest_timestamp], closest_timestamp | ||||||
|  |  | ||||||
|  |     def debug_info(self, timestamps): | ||||||
|  |         xstrs = [] | ||||||
|  |         for timestamp in timestamps: | ||||||
|  |             if timestamp in self._time2score: | ||||||
|  |                 xstrs.append( | ||||||
|  |                     "{:04d}: {:.4f}".format(timestamp, self._time2score[timestamp]) | ||||||
|  |                 ) | ||||||
|  |         return ", ".join(xstrs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def main(args): | ||||||
|  |     prepare_seed(args.rand_seed) | ||||||
|  |     logger = prepare_logger(args) | ||||||
|  |  | ||||||
|  |     cache_path = (logger.path(None) / ".." / "env-info.pth").resolve() | ||||||
|  |     if cache_path.exists(): | ||||||
|  |         env_info = torch.load(cache_path) | ||||||
|  |     else: | ||||||
|  |         env_info = dict() | ||||||
|  |         dynamic_env = get_synthetic_env() | ||||||
|  |         env_info["total"] = len(dynamic_env) | ||||||
|  |         for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): | ||||||
|  |             env_info["{:}-timestamp".format(idx)] = timestamp | ||||||
|  |             env_info["{:}-x".format(idx)] = _allx | ||||||
|  |             env_info["{:}-y".format(idx)] = _ally | ||||||
|  |         env_info["dynamic_env"] = dynamic_env | ||||||
|  |         torch.save(env_info, cache_path) | ||||||
|  |  | ||||||
|  |     total_time = env_info["total"] | ||||||
|  |     for i in range(total_time): | ||||||
|  |         for xkey in ("timestamp", "x", "y"): | ||||||
|  |             nkey = "{:}-{:}".format(i, xkey) | ||||||
|  |             assert nkey in env_info, "{:} not in {:}".format(nkey, list(env_info.keys())) | ||||||
|  |     train_time_bar = total_time // 2 | ||||||
|  |     base_model = get_model( | ||||||
|  |         dict(model_type="simple_mlp"), | ||||||
|  |         act_cls="leaky_relu", | ||||||
|  |         norm_cls="identity", | ||||||
|  |         input_dim=1, | ||||||
|  |         output_dim=1, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     w_container = base_model.get_w_container() | ||||||
|  |     criterion = torch.nn.MSELoss() | ||||||
|  |     print("There are {:} weights.".format(w_container.numel())) | ||||||
|  |  | ||||||
|  |     adaptor = LFNAmlp(4, (50, 20), "leaky_relu") | ||||||
|  |  | ||||||
|  |     pool = Population() | ||||||
|  |     pool.append(0, w_container, -100) | ||||||
|  |  | ||||||
|  |     # LFNA meta-training | ||||||
|  |     per_epoch_time, start_time = AverageMeter(), time.time() | ||||||
|  |     for iepoch in range(args.epochs): | ||||||
|  |  | ||||||
|  |         need_time = "Time Left: {:}".format( | ||||||
|  |             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) | ||||||
|  |         ) | ||||||
|  |         logger.log( | ||||||
|  |             "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs) | ||||||
|  |             + need_time | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         adaptor.zero_grad() | ||||||
|  |  | ||||||
|  |         debug_timestamp = set() | ||||||
|  |         all_meta_losses = [] | ||||||
|  |         for ibatch in range(args.meta_batch): | ||||||
|  |             sampled_timestamp = random.randint(0, train_time_bar) | ||||||
|  |             query_w_container, query_timestamp = pool.query(sampled_timestamp) | ||||||
|  |             # def adapt(self, model, w_container, xs, ys): | ||||||
|  |             seq_datasets = [] | ||||||
|  |             # xs, ys = [], [] | ||||||
|  |             for it in range(sampled_timestamp, sampled_timestamp + args.max_seq): | ||||||
|  |                 xs = env_info["{:}-x".format(it)] | ||||||
|  |                 ys = env_info["{:}-y".format(it)] | ||||||
|  |                 seq_datasets.append(TimeData(it, xs, ys)) | ||||||
|  |             temp_meta_loss, temp_containers = adaptor.adapt( | ||||||
|  |                 base_model, criterion, query_w_container, seq_datasets | ||||||
|  |             ) | ||||||
|  |             all_meta_losses.append(temp_meta_loss) | ||||||
|  |             for temp_time, temp_container, temp_score in temp_containers: | ||||||
|  |                 pool.append(temp_time, temp_container, temp_score) | ||||||
|  |                 debug_timestamp.add(temp_time) | ||||||
|  |         meta_loss = torch.stack(all_meta_losses).mean() | ||||||
|  |         meta_loss.backward() | ||||||
|  |         adaptor.step() | ||||||
|  |  | ||||||
|  |         debug_str = pool.debug_info(debug_timestamp) | ||||||
|  |         logger.log("meta-loss: {:.4f}".format(meta_loss.item())) | ||||||
|  |  | ||||||
|  |         per_epoch_time.update(time.time() - start_time) | ||||||
|  |         start_time = time.time() | ||||||
|  |  | ||||||
|  |     logger.log("-" * 200 + "\n") | ||||||
|  |     logger.close() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     parser = argparse.ArgumentParser("Use the data in the past.") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--save_dir", | ||||||
|  |         type=str, | ||||||
|  |         default="./outputs/lfna-synthetic/lfna-v1", | ||||||
|  |         help="The checkpoint directory.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--init_lr", | ||||||
|  |         type=float, | ||||||
|  |         default=0.1, | ||||||
|  |         help="The initial learning rate for the optimizer (default is Adam)", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--meta_batch", | ||||||
|  |         type=int, | ||||||
|  |         default=5, | ||||||
|  |         help="The batch size for the meta-model", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--epochs", | ||||||
|  |         type=int, | ||||||
|  |         default=1000, | ||||||
|  |         help="The total number of epochs.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--max_seq", | ||||||
|  |         type=int, | ||||||
|  |         default=5, | ||||||
|  |         help="The maximum length of the sequence.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--workers", | ||||||
|  |         type=int, | ||||||
|  |         default=4, | ||||||
|  |         help="The number of data loading workers (default: 4)", | ||||||
|  |     ) | ||||||
|  |     # Random Seed | ||||||
|  |     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||||
|  |     args = parser.parse_args() | ||||||
|  |     if args.rand_seed is None or args.rand_seed < 0: | ||||||
|  |         args.rand_seed = random.randint(1, 100000) | ||||||
|  |     assert args.save_dir is not None, "The save dir argument can not be None" | ||||||
|  |     main(args) | ||||||
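To summarize the adaptation step in `LFNAmlp.adapt` above: every scalar weight is described by four features (its value, its gradient, and the mean/std of the current inputs), and the delta-net predicts a clipped per-weight change for the next timestamp. A condensed sketch using a flat parameter vector instead of the container API (the shapes and helper names here are illustrative assumptions):

```python
import torch
import torch.nn as nn

# mirrors LFNAmlp(obs_dim=4, hidden_sizes=(50, 20), act_name="leaky_relu")
delta_net = nn.Sequential(
    nn.Linear(4, 50), nn.LeakyReLU(),
    nn.Linear(50, 20), nn.LeakyReLU(),
    nn.Linear(20, 1),
)


def predict_next_weights(flat_w, flat_g, x):
    # per-weight features: [weight value, gradient, mean(x), std(x)]
    stats = torch.tensor([x.mean(), x.std()]).view(1, 2).expand(flat_w.numel(), -1)
    feats = torch.cat((flat_w.view(-1, 1), flat_g.view(-1, 1), stats), dim=-1)
    delta = torch.clamp(delta_net(feats).view(-1), -0.5, 0.5)
    return flat_w + delta  # proposed weights for the next timestamp


# toy usage: 10 weights, their gradients, and a batch of inputs
flat_w, flat_g = torch.randn(10), torch.randn(10)
x = torch.randn(64, 1)
print(predict_next_weights(flat_w, flat_g, x).shape)  # torch.Size([10])
```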
							
								
								
									
exps/LFNA/lfna_utils.py (new file, 61 lines)
							| @@ -0,0 +1,61 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||||
|  | ##################################################### | ||||||
|  | import torch | ||||||
|  | from tqdm import tqdm | ||||||
|  | from procedures import prepare_seed, prepare_logger | ||||||
|  | from datasets.synthetic_core import get_synthetic_env | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def lfna_setup(args): | ||||||
|  |     prepare_seed(args.rand_seed) | ||||||
|  |     logger = prepare_logger(args) | ||||||
|  |  | ||||||
|  |     cache_path = ( | ||||||
|  |         logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version) | ||||||
|  |     ).resolve() | ||||||
|  |     if cache_path.exists(): | ||||||
|  |         env_info = torch.load(cache_path) | ||||||
|  |     else: | ||||||
|  |         env_info = dict() | ||||||
|  |         dynamic_env = get_synthetic_env(version=args.env_version) | ||||||
|  |         env_info["total"] = len(dynamic_env) | ||||||
|  |         for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): | ||||||
|  |             env_info["{:}-timestamp".format(idx)] = timestamp | ||||||
|  |             env_info["{:}-x".format(idx)] = _allx | ||||||
|  |             env_info["{:}-y".format(idx)] = _ally | ||||||
|  |         env_info["dynamic_env"] = dynamic_env | ||||||
|  |         torch.save(env_info, cache_path) | ||||||
|  |  | ||||||
|  |     model_kwargs = dict( | ||||||
|  |         input_dim=1, | ||||||
|  |         output_dim=1, | ||||||
|  |         hidden_dim=args.hidden_dim, | ||||||
|  |         act_cls="leaky_relu", | ||||||
|  |         norm_cls="identity", | ||||||
|  |     ) | ||||||
|  |     return logger, env_info, model_kwargs | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TimeData: | ||||||
|  |     def __init__(self, timestamp, xs, ys): | ||||||
|  |         self._timestamp = timestamp | ||||||
|  |         self._xs = xs | ||||||
|  |         self._ys = ys | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def x(self): | ||||||
|  |         return self._xs | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def y(self): | ||||||
|  |         return self._ys | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def timestamp(self): | ||||||
|  |         return self._timestamp | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}(timestamp={:}, with {num} samples)".format( | ||||||
|  |             name=self.__class__.__name__, timestamp=self._timestamp, num=len(self._xs) | ||||||
|  |         ) | ||||||
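A short usage sketch of these helpers, assuming an `args` namespace that carries the fields referenced above (`rand_seed`, `save_dir`, `env_version`, `hidden_dim`); note that `prepare_logger` may expect additional attributes beyond the ones listed here:

```python
import argparse
from lfna_utils import lfna_setup, TimeData

# hypothetical namespace; field names match the LFNA scripts in this commit
args = argparse.Namespace(
    rand_seed=1,
    save_dir="./outputs/lfna-synthetic/use-same-timestamp-v1-d16",
    env_version="v1",
    hidden_dim=16,
)
logger, env_info, model_kwargs = lfna_setup(args)

# wrap one cached timestamp of the synthetic environment as a TimeData batch
idx = 0
batch = TimeData(
    env_info["{:}-timestamp".format(idx)],
    env_info["{:}-x".format(idx)],
    env_info["{:}-y".format(idx)],
)
print(batch)  # TimeData(timestamp=..., with N samples)
```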
| @@ -8,98 +8,110 @@ from .initialization import initialize_resnet | |||||||
|  |  | ||||||
|  |  | ||||||
| class Bottleneck(nn.Module): | class Bottleneck(nn.Module): | ||||||
|   def __init__(self, nChannels, growthRate): |     def __init__(self, nChannels, growthRate): | ||||||
|     super(Bottleneck, self).__init__() |         super(Bottleneck, self).__init__() | ||||||
|     interChannels = 4*growthRate |         interChannels = 4 * growthRate | ||||||
|     self.bn1 = nn.BatchNorm2d(nChannels) |         self.bn1 = nn.BatchNorm2d(nChannels) | ||||||
|     self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False) |         self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False) | ||||||
|     self.bn2 = nn.BatchNorm2d(interChannels) |         self.bn2 = nn.BatchNorm2d(interChannels) | ||||||
|     self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False) |         self.conv2 = nn.Conv2d( | ||||||
|  |             interChannels, growthRate, kernel_size=3, padding=1, bias=False | ||||||
|  |         ) | ||||||
|  |  | ||||||
|   def forward(self, x): |     def forward(self, x): | ||||||
|     out = self.conv1(F.relu(self.bn1(x))) |         out = self.conv1(F.relu(self.bn1(x))) | ||||||
|     out = self.conv2(F.relu(self.bn2(out))) |         out = self.conv2(F.relu(self.bn2(out))) | ||||||
|     out = torch.cat((x, out), 1) |         out = torch.cat((x, out), 1) | ||||||
|     return out |         return out | ||||||
|  |  | ||||||
|  |  | ||||||
| class SingleLayer(nn.Module): | class SingleLayer(nn.Module): | ||||||
|   def __init__(self, nChannels, growthRate): |     def __init__(self, nChannels, growthRate): | ||||||
|     super(SingleLayer, self).__init__() |         super(SingleLayer, self).__init__() | ||||||
|     self.bn1 = nn.BatchNorm2d(nChannels) |         self.bn1 = nn.BatchNorm2d(nChannels) | ||||||
|     self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False) |         self.conv1 = nn.Conv2d( | ||||||
|  |             nChannels, growthRate, kernel_size=3, padding=1, bias=False | ||||||
|  |         ) | ||||||
|  |  | ||||||
|   def forward(self, x): |     def forward(self, x): | ||||||
|     out = self.conv1(F.relu(self.bn1(x))) |         out = self.conv1(F.relu(self.bn1(x))) | ||||||
|     out = torch.cat((x, out), 1) |         out = torch.cat((x, out), 1) | ||||||
|     return out |         return out | ||||||
|  |  | ||||||
|  |  | ||||||
| class Transition(nn.Module): | class Transition(nn.Module): | ||||||
|   def __init__(self, nChannels, nOutChannels): |     def __init__(self, nChannels, nOutChannels): | ||||||
|     super(Transition, self).__init__() |         super(Transition, self).__init__() | ||||||
|     self.bn1 = nn.BatchNorm2d(nChannels) |         self.bn1 = nn.BatchNorm2d(nChannels) | ||||||
|     self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False) |         self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False) | ||||||
|  |  | ||||||
|   def forward(self, x): |     def forward(self, x): | ||||||
|     out = self.conv1(F.relu(self.bn1(x))) |         out = self.conv1(F.relu(self.bn1(x))) | ||||||
|     out = F.avg_pool2d(out, 2) |         out = F.avg_pool2d(out, 2) | ||||||
|     return out |         return out | ||||||
|  |  | ||||||
|  |  | ||||||
| class DenseNet(nn.Module): | class DenseNet(nn.Module): | ||||||
|   def __init__(self, growthRate, depth, reduction, nClasses, bottleneck): |     def __init__(self, growthRate, depth, reduction, nClasses, bottleneck): | ||||||
|     super(DenseNet, self).__init__() |         super(DenseNet, self).__init__() | ||||||
|  |  | ||||||
|     if bottleneck:  nDenseBlocks = int( (depth-4) / 6 ) |         if bottleneck: | ||||||
|     else         :  nDenseBlocks = int( (depth-4) / 3 ) |             nDenseBlocks = int((depth - 4) / 6) | ||||||
|  |         else: | ||||||
|  |             nDenseBlocks = int((depth - 4) / 3) | ||||||
|  |  | ||||||
|     self.message = 'CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}'.format('bottleneck' if bottleneck else 'basic', depth, reduction, growthRate, nClasses) |         self.message = "CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}".format( | ||||||
|  |             "bottleneck" if bottleneck else "basic", | ||||||
|  |             depth, | ||||||
|  |             reduction, | ||||||
|  |             growthRate, | ||||||
|  |             nClasses, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     nChannels = 2*growthRate |         nChannels = 2 * growthRate | ||||||
|     self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False) |         self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False) | ||||||
|  |  | ||||||
|     self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) |         self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) | ||||||
|     nChannels += nDenseBlocks*growthRate |         nChannels += nDenseBlocks * growthRate | ||||||
|     nOutChannels = int(math.floor(nChannels*reduction)) |         nOutChannels = int(math.floor(nChannels * reduction)) | ||||||
|     self.trans1 = Transition(nChannels, nOutChannels) |         self.trans1 = Transition(nChannels, nOutChannels) | ||||||
|  |  | ||||||
|     nChannels = nOutChannels |         nChannels = nOutChannels | ||||||
|     self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) |         self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) | ||||||
|     nChannels += nDenseBlocks*growthRate |         nChannels += nDenseBlocks * growthRate | ||||||
|     nOutChannels = int(math.floor(nChannels*reduction)) |         nOutChannels = int(math.floor(nChannels * reduction)) | ||||||
|     self.trans2 = Transition(nChannels, nOutChannels) |         self.trans2 = Transition(nChannels, nOutChannels) | ||||||
|  |  | ||||||
|     nChannels = nOutChannels |         nChannels = nOutChannels | ||||||
|     self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) |         self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) | ||||||
|     nChannels += nDenseBlocks*growthRate |         nChannels += nDenseBlocks * growthRate | ||||||
|  |  | ||||||
|     self.act = nn.Sequential( |         self.act = nn.Sequential( | ||||||
|                   nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True), |             nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True), nn.AvgPool2d(8) | ||||||
|                   nn.AvgPool2d(8)) |         ) | ||||||
|     self.fc  = nn.Linear(nChannels, nClasses) |         self.fc = nn.Linear(nChannels, nClasses) | ||||||
|  |  | ||||||
|     self.apply(initialize_resnet) |         self.apply(initialize_resnet) | ||||||
|  |  | ||||||
|   def get_message(self): |     def get_message(self): | ||||||
|     return self.message |         return self.message | ||||||
|  |  | ||||||
|   def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck): |     def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck): | ||||||
|     layers = [] |         layers = [] | ||||||
|     for i in range(int(nDenseBlocks)): |         for i in range(int(nDenseBlocks)): | ||||||
|       if bottleneck: |             if bottleneck: | ||||||
|         layers.append(Bottleneck(nChannels, growthRate)) |                 layers.append(Bottleneck(nChannels, growthRate)) | ||||||
|       else: |             else: | ||||||
|         layers.append(SingleLayer(nChannels, growthRate)) |                 layers.append(SingleLayer(nChannels, growthRate)) | ||||||
|       nChannels += growthRate |             nChannels += growthRate | ||||||
|     return nn.Sequential(*layers) |         return nn.Sequential(*layers) | ||||||
|  |  | ||||||
|   def forward(self, inputs): |     def forward(self, inputs): | ||||||
|     out = self.conv1( inputs ) |         out = self.conv1(inputs) | ||||||
|     out = self.trans1(self.dense1(out)) |         out = self.trans1(self.dense1(out)) | ||||||
|     out = self.trans2(self.dense2(out)) |         out = self.trans2(self.dense2(out)) | ||||||
|     out = self.dense3(out) |         out = self.dense3(out) | ||||||
|     features = self.act(out) |         features = self.act(out) | ||||||
|     features = features.view(features.size(0), -1) |         features = features.view(features.size(0), -1) | ||||||
|     out = self.fc(features) |         out = self.fc(features) | ||||||
|     return features, out |         return features, out | ||||||
|   | |||||||
| @@ -2,156 +2,179 @@ import torch | |||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||||
| from .initialization import initialize_resnet | from .initialization import initialize_resnet | ||||||
| from .SharedUtils    import additive_func | from .SharedUtils import additive_func | ||||||
|  |  | ||||||
|  |  | ||||||
| class Downsample(nn.Module): | class Downsample(nn.Module): | ||||||
|  |     def __init__(self, nIn, nOut, stride): | ||||||
|  |         super(Downsample, self).__init__() | ||||||
|  |         assert stride == 2 and nOut == 2 * nIn, "stride:{} IO:{},{}".format( | ||||||
|  |             stride, nIn, nOut | ||||||
|  |         ) | ||||||
|  |         self.in_dim = nIn | ||||||
|  |         self.out_dim = nOut | ||||||
|  |         self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||||
|  |         self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=1, padding=0, bias=False) | ||||||
|  |  | ||||||
|   def __init__(self, nIn, nOut, stride): |     def forward(self, x): | ||||||
|     super(Downsample, self).__init__()  |         x = self.avg(x) | ||||||
|     assert stride == 2 and nOut == 2*nIn, 'stride:{} IO:{},{}'.format(stride, nIn, nOut) |         out = self.conv(x) | ||||||
|     self.in_dim  = nIn |         return out | ||||||
|     self.out_dim = nOut |  | ||||||
|     self.avg  = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)    |  | ||||||
|     self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=1, padding=0, bias=False) |  | ||||||
|  |  | ||||||
|   def forward(self, x): |  | ||||||
|     x   = self.avg(x) |  | ||||||
|     out = self.conv(x) |  | ||||||
|     return out |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ConvBNReLU(nn.Module): | class ConvBNReLU(nn.Module): | ||||||
|  |     def __init__(self, nIn, nOut, kernel, stride, padding, bias, relu): | ||||||
|  |         super(ConvBNReLU, self).__init__() | ||||||
|  |         self.conv = nn.Conv2d( | ||||||
|  |             nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, bias=bias | ||||||
|  |         ) | ||||||
|  |         self.bn = nn.BatchNorm2d(nOut) | ||||||
|  |         if relu: | ||||||
|  |             self.relu = nn.ReLU(inplace=True) | ||||||
|  |         else: | ||||||
|  |             self.relu = None | ||||||
|  |         self.out_dim = nOut | ||||||
|  |         self.num_conv = 1 | ||||||
|  |  | ||||||
|   def __init__(self, nIn, nOut, kernel, stride, padding, bias, relu): |     def forward(self, x): | ||||||
|     super(ConvBNReLU, self).__init__() |         conv = self.conv(x) | ||||||
|     self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, bias=bias) |         bn = self.bn(conv) | ||||||
|     self.bn   = nn.BatchNorm2d(nOut) |         if self.relu: | ||||||
|     if relu: self.relu = nn.ReLU(inplace=True) |             return self.relu(bn) | ||||||
|     else   : self.relu = None |         else: | ||||||
|     self.out_dim = nOut |             return bn | ||||||
|     self.num_conv = 1 |  | ||||||
|  |  | ||||||
|   def forward(self, x): |  | ||||||
|     conv = self.conv( x ) |  | ||||||
|     bn   = self.bn( conv ) |  | ||||||
|     if self.relu: return self.relu( bn ) |  | ||||||
|     else        : return bn |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ResNetBasicblock(nn.Module): | class ResNetBasicblock(nn.Module): | ||||||
|   expansion = 1 |     expansion = 1 | ||||||
|   def __init__(self, inplanes, planes, stride): |  | ||||||
|     super(ResNetBasicblock, self).__init__() |  | ||||||
|     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) |  | ||||||
|     self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, True) |  | ||||||
|     self.conv_b = ConvBNReLU(  planes, planes, 3,      1, 1, False, False) |  | ||||||
|     if stride == 2: |  | ||||||
|       self.downsample = Downsample(inplanes, planes, stride) |  | ||||||
|     elif inplanes != planes: |  | ||||||
|       self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, False) |  | ||||||
|     else: |  | ||||||
|       self.downsample = None |  | ||||||
|     self.out_dim = planes |  | ||||||
|     self.num_conv = 2 |  | ||||||
|  |  | ||||||
|   def forward(self, inputs): |     def __init__(self, inplanes, planes, stride): | ||||||
|  |         super(ResNetBasicblock, self).__init__() | ||||||
|  |         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||||
|  |         self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, True) | ||||||
|  |         self.conv_b = ConvBNReLU(planes, planes, 3, 1, 1, False, False) | ||||||
|  |         if stride == 2: | ||||||
|  |             self.downsample = Downsample(inplanes, planes, stride) | ||||||
|  |         elif inplanes != planes: | ||||||
|  |             self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, False) | ||||||
|  |         else: | ||||||
|  |             self.downsample = None | ||||||
|  |         self.out_dim = planes | ||||||
|  |         self.num_conv = 2 | ||||||
|  |  | ||||||
|     basicblock = self.conv_a(inputs) |     def forward(self, inputs): | ||||||
|     basicblock = self.conv_b(basicblock) |  | ||||||
|  |  | ||||||
|     if self.downsample is not None: |         basicblock = self.conv_a(inputs) | ||||||
|       residual = self.downsample(inputs) |         basicblock = self.conv_b(basicblock) | ||||||
|     else: |  | ||||||
|       residual = inputs |  | ||||||
|     out = additive_func(residual, basicblock) |  | ||||||
|     return F.relu(out, inplace=True) |  | ||||||
|  |  | ||||||
|  |         if self.downsample is not None: | ||||||
|  |             residual = self.downsample(inputs) | ||||||
|  |         else: | ||||||
|  |             residual = inputs | ||||||
|  |         out = additive_func(residual, basicblock) | ||||||
|  |         return F.relu(out, inplace=True) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ResNetBottleneck(nn.Module): | class ResNetBottleneck(nn.Module): | ||||||
|   expansion = 4 |     expansion = 4 | ||||||
|   def __init__(self, inplanes, planes, stride): |  | ||||||
|     super(ResNetBottleneck, self).__init__() |  | ||||||
|     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) |  | ||||||
|     self.conv_1x1 = ConvBNReLU(inplanes, planes, 1,      1, 0, False, True) |  | ||||||
|     self.conv_3x3 = ConvBNReLU(  planes, planes, 3, stride, 1, False, True) |  | ||||||
|     self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, False) |  | ||||||
|     if stride == 2: |  | ||||||
|       self.downsample = Downsample(inplanes, planes*self.expansion, stride) |  | ||||||
|     elif inplanes != planes*self.expansion: |  | ||||||
|       self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, False) |  | ||||||
|     else: |  | ||||||
|       self.downsample = None |  | ||||||
|     self.out_dim = planes * self.expansion |  | ||||||
|     self.num_conv = 3 |  | ||||||
|  |  | ||||||
|   def forward(self, inputs): |     def __init__(self, inplanes, planes, stride): | ||||||
|  |         super(ResNetBottleneck, self).__init__() | ||||||
|  |         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||||
|  |         self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, True) | ||||||
|  |         self.conv_3x3 = ConvBNReLU(planes, planes, 3, stride, 1, False, True) | ||||||
|  |         self.conv_1x4 = ConvBNReLU( | ||||||
|  |             planes, planes * self.expansion, 1, 1, 0, False, False | ||||||
|  |         ) | ||||||
|  |         if stride == 2: | ||||||
|  |             self.downsample = Downsample(inplanes, planes * self.expansion, stride) | ||||||
|  |         elif inplanes != planes * self.expansion: | ||||||
|  |             self.downsample = ConvBNReLU( | ||||||
|  |                 inplanes, planes * self.expansion, 1, 1, 0, False, False | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             self.downsample = None | ||||||
|  |         self.out_dim = planes * self.expansion | ||||||
|  |         self.num_conv = 3 | ||||||
|  |  | ||||||
|     bottleneck = self.conv_1x1(inputs) |     def forward(self, inputs): | ||||||
|     bottleneck = self.conv_3x3(bottleneck) |  | ||||||
|     bottleneck = self.conv_1x4(bottleneck) |  | ||||||
|  |  | ||||||
|     if self.downsample is not None: |         bottleneck = self.conv_1x1(inputs) | ||||||
|       residual = self.downsample(inputs) |         bottleneck = self.conv_3x3(bottleneck) | ||||||
|     else: |         bottleneck = self.conv_1x4(bottleneck) | ||||||
|       residual = inputs |  | ||||||
|     out = additive_func(residual, bottleneck) |  | ||||||
|     return F.relu(out, inplace=True) |  | ||||||
|  |  | ||||||
|  |         if self.downsample is not None: | ||||||
|  |             residual = self.downsample(inputs) | ||||||
|  |         else: | ||||||
|  |             residual = inputs | ||||||
|  |         out = additive_func(residual, bottleneck) | ||||||
|  |         return F.relu(out, inplace=True) | ||||||
|  |  | ||||||
|  |  | ||||||
| class CifarResNet(nn.Module): | class CifarResNet(nn.Module): | ||||||
|  |     def __init__(self, block_name, depth, num_classes, zero_init_residual): | ||||||
|  |         super(CifarResNet, self).__init__() | ||||||
|  |  | ||||||
|   def __init__(self, block_name, depth, num_classes, zero_init_residual): |         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||||
|     super(CifarResNet, self).__init__() |         if block_name == "ResNetBasicblock": | ||||||
|  |             block = ResNetBasicblock | ||||||
|  |             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||||
|  |             layer_blocks = (depth - 2) // 6 | ||||||
|  |         elif block_name == "ResNetBottleneck": | ||||||
|  |             block = ResNetBottleneck | ||||||
|  |             assert (depth - 2) % 9 == 0, "depth should be one of 164" | ||||||
|  |             layer_blocks = (depth - 2) // 9 | ||||||
|  |         else: | ||||||
|  |             raise ValueError("invalid block : {:}".format(block_name)) | ||||||
|  |  | ||||||
|     #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model |         self.message = "CifarResNet : Block : {:}, Depth : {:}, Layers for each block : {:}".format( | ||||||
|     if block_name == 'ResNetBasicblock': |             block_name, depth, layer_blocks | ||||||
|       block = ResNetBasicblock |         ) | ||||||
|       assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' |         self.num_classes = num_classes | ||||||
|       layer_blocks = (depth - 2) // 6 |         self.channels = [16] | ||||||
|     elif block_name == 'ResNetBottleneck': |         self.layers = nn.ModuleList([ConvBNReLU(3, 16, 3, 1, 1, False, True)]) | ||||||
|       block = ResNetBottleneck |         for stage in range(3): | ||||||
|       assert (depth - 2) % 9 == 0, 'depth should be one of 164' |             for iL in range(layer_blocks): | ||||||
|       layer_blocks = (depth - 2) // 9 |                 iC = self.channels[-1] | ||||||
|     else: |                 planes = 16 * (2 ** stage) | ||||||
|       raise ValueError('invalid block : {:}'.format(block_name)) |                 stride = 2 if stage > 0 and iL == 0 else 1 | ||||||
|  |                 module = block(iC, planes, stride) | ||||||
|  |                 self.channels.append(module.out_dim) | ||||||
|  |                 self.layers.append(module) | ||||||
|  |                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( | ||||||
|  |                     stage, | ||||||
|  |                     iL, | ||||||
|  |                     layer_blocks, | ||||||
|  |                     len(self.layers) - 1, | ||||||
|  |                     iC, | ||||||
|  |                     module.out_dim, | ||||||
|  |                     stride, | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|     self.message     = 'CifarResNet : Block : {:}, Depth : {:}, Layers for each block : {:}'.format(block_name, depth, layer_blocks) |         self.avgpool = nn.AvgPool2d(8) | ||||||
|     self.num_classes = num_classes |         self.classifier = nn.Linear(module.out_dim, num_classes) | ||||||
|     self.channels    = [16] |         assert ( | ||||||
|     self.layers      = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, True) ] ) |             sum(x.num_conv for x in self.layers) + 1 == depth | ||||||
|     for stage in range(3): |         ), "invalid depth check {:} vs {:}".format( | ||||||
|       for iL in range(layer_blocks): |             sum(x.num_conv for x in self.layers) + 1, depth | ||||||
|         iC     = self.channels[-1] |         ) | ||||||
|         planes = 16 * (2**stage) |  | ||||||
|         stride = 2 if stage > 0 and iL == 0 else 1 |  | ||||||
|         module = block(iC, planes, stride) |  | ||||||
|         self.channels.append( module.out_dim ) |  | ||||||
|         self.layers.append  ( module ) |  | ||||||
|         self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride) |  | ||||||
|  |  | ||||||
|     self.avgpool = nn.AvgPool2d(8) |         self.apply(initialize_resnet) | ||||||
|     self.classifier = nn.Linear(module.out_dim, num_classes) |         if zero_init_residual: | ||||||
|     assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) |             for m in self.modules(): | ||||||
|  |                 if isinstance(m, ResNetBasicblock): | ||||||
|  |                     nn.init.constant_(m.conv_b.bn.weight, 0) | ||||||
|  |                 elif isinstance(m, ResNetBottleneck): | ||||||
|  |                     nn.init.constant_(m.conv_1x4.bn.weight, 0) | ||||||
|  |  | ||||||
|     self.apply(initialize_resnet) |     def get_message(self): | ||||||
|     if zero_init_residual: |         return self.message | ||||||
|       for m in self.modules(): |  | ||||||
|         if isinstance(m, ResNetBasicblock): |  | ||||||
|           nn.init.constant_(m.conv_b.bn.weight, 0) |  | ||||||
|         elif isinstance(m, ResNetBottleneck): |  | ||||||
|           nn.init.constant_(m.conv_1x4.bn.weight, 0) |  | ||||||
|  |  | ||||||
|   def get_message(self): |     def forward(self, inputs): | ||||||
|     return self.message |         x = inputs | ||||||
|  |         for i, layer in enumerate(self.layers): | ||||||
|   def forward(self, inputs): |             x = layer(x) | ||||||
|     x = inputs |         features = self.avgpool(x) | ||||||
|     for i, layer in enumerate(self.layers): |         features = features.view(features.size(0), -1) | ||||||
|       x = layer( x ) |         logits = self.classifier(features) | ||||||
|     features = self.avgpool(x) |         return features, logits | ||||||
|     features = features.view(features.size(0), -1) |  | ||||||
|     logits   = self.classifier(features) |  | ||||||
|     return features, logits |  | ||||||
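A minimal usage sketch for the reformatted CifarResNet above; the import path is an assumption about how this file is exposed inside the models package.

import torch
from models.CifarResNet import CifarResNet  # assumed import path

# depth=20 with basic blocks -> (20 - 2) // 6 = 3 blocks per stage
net = CifarResNet("ResNetBasicblock", 20, num_classes=10, zero_init_residual=True)
x = torch.randn(2, 3, 32, 32)           # dummy CIFAR-sized batch
features, logits = net(x)               # forward returns (features, logits)
print(features.shape, logits.shape)     # expected: [2, 64] and [2, 10]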
|   | |||||||
| @@ -5,90 +5,111 @@ from .initialization import initialize_resnet | |||||||
|  |  | ||||||
|  |  | ||||||
| class WideBasicblock(nn.Module): | class WideBasicblock(nn.Module): | ||||||
|   def __init__(self, inplanes, planes, stride, dropout=False): |     def __init__(self, inplanes, planes, stride, dropout=False): | ||||||
|     super(WideBasicblock, self).__init__() |         super(WideBasicblock, self).__init__() | ||||||
|  |  | ||||||
|     self.bn_a = nn.BatchNorm2d(inplanes) |         self.bn_a = nn.BatchNorm2d(inplanes) | ||||||
|     self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) |         self.conv_a = nn.Conv2d( | ||||||
|  |             inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     self.bn_b = nn.BatchNorm2d(planes) |         self.bn_b = nn.BatchNorm2d(planes) | ||||||
|     if dropout: |         if dropout: | ||||||
|       self.dropout = nn.Dropout2d(p=0.5, inplace=True) |             self.dropout = nn.Dropout2d(p=0.5, inplace=True) | ||||||
|     else: |         else: | ||||||
|       self.dropout = None |             self.dropout = None | ||||||
|     self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) |         self.conv_b = nn.Conv2d( | ||||||
|  |             planes, planes, kernel_size=3, stride=1, padding=1, bias=False | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     if inplanes != planes: |         if inplanes != planes: | ||||||
|       self.downsample = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False) |             self.downsample = nn.Conv2d( | ||||||
|     else: |                 inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False | ||||||
|       self.downsample = None |             ) | ||||||
|  |         else: | ||||||
|  |             self.downsample = None | ||||||
|  |  | ||||||
|   def forward(self, x): |     def forward(self, x): | ||||||
|  |  | ||||||
|     basicblock = self.bn_a(x) |         basicblock = self.bn_a(x) | ||||||
|     basicblock = F.relu(basicblock) |         basicblock = F.relu(basicblock) | ||||||
|     basicblock = self.conv_a(basicblock) |         basicblock = self.conv_a(basicblock) | ||||||
|  |  | ||||||
|     basicblock = self.bn_b(basicblock) |         basicblock = self.bn_b(basicblock) | ||||||
|     basicblock = F.relu(basicblock) |         basicblock = F.relu(basicblock) | ||||||
|     if self.dropout is not None: |         if self.dropout is not None: | ||||||
|       basicblock = self.dropout(basicblock) |             basicblock = self.dropout(basicblock) | ||||||
|     basicblock = self.conv_b(basicblock) |         basicblock = self.conv_b(basicblock) | ||||||
|  |  | ||||||
|     if self.downsample is not None: |         if self.downsample is not None: | ||||||
|       x = self.downsample(x) |             x = self.downsample(x) | ||||||
|  |  | ||||||
|     return x + basicblock |         return x + basicblock | ||||||
|  |  | ||||||
|  |  | ||||||
| class CifarWideResNet(nn.Module): | class CifarWideResNet(nn.Module): | ||||||
|   """ |     """ | ||||||
|   Wide ResNet optimized for the Cifar dataset, as specified in |     Wide ResNet optimized for the Cifar dataset, as specified in | ||||||
|   https://arxiv.org/abs/1605.07146 |     https://arxiv.org/abs/1605.07146 | ||||||
|   """ |     """ | ||||||
|   def __init__(self, depth, widen_factor, num_classes, dropout): |  | ||||||
|     super(CifarWideResNet, self).__init__() |  | ||||||
|  |  | ||||||
|     #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model |     def __init__(self, depth, widen_factor, num_classes, dropout): | ||||||
|     assert (depth - 4) % 6 == 0, 'depth should be 6n+4, e.g., 16, 22, 28, 40' |         assert (depth - 4) % 6 == 0, "depth should be 6n+4, e.g., 16, 22, 28, 40" | ||||||
|     layer_blocks = (depth - 4) // 6 |  | ||||||
|     print ('CifarWideResNet : Depth : {} , Layers for each block : {}'.format(depth, layer_blocks)) |         print( | ||||||
|  |  | ||||||
|     self.num_classes = num_classes |         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||||
|     self.dropout = dropout |         assert (depth - 4) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||||
|     self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) |         layer_blocks = (depth - 4) // 6 | ||||||
|  |         print( | ||||||
|  |             "CifarPreResNet : Depth : {} , Layers for each block : {}".format( | ||||||
|  |                 depth, layer_blocks | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     self.message  = 'Wide ResNet : depth={:}, widen_factor={:}, class={:}'.format(depth, widen_factor, num_classes) |         self.num_classes = num_classes | ||||||
|     self.inplanes = 16 |         self.dropout = dropout | ||||||
|     self.stage_1 = self._make_layer(WideBasicblock, 16*widen_factor, layer_blocks, 1) |         self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) | ||||||
|     self.stage_2 = self._make_layer(WideBasicblock, 32*widen_factor, layer_blocks, 2) |  | ||||||
|     self.stage_3 = self._make_layer(WideBasicblock, 64*widen_factor, layer_blocks, 2) |  | ||||||
|     self.lastact = nn.Sequential(nn.BatchNorm2d(64*widen_factor), nn.ReLU(inplace=True)) |  | ||||||
|     self.avgpool = nn.AvgPool2d(8) |  | ||||||
|     self.classifier = nn.Linear(64*widen_factor, num_classes) |  | ||||||
|  |  | ||||||
|     self.apply(initialize_resnet) |         self.message = "Wide ResNet : depth={:}, widen_factor={:}, class={:}".format( | ||||||
|  |             depth, widen_factor, num_classes | ||||||
|  |         ) | ||||||
|  |         self.inplanes = 16 | ||||||
|  |         self.stage_1 = self._make_layer( | ||||||
|  |             WideBasicblock, 16 * widen_factor, layer_blocks, 1 | ||||||
|  |         ) | ||||||
|  |         self.stage_2 = self._make_layer( | ||||||
|  |             WideBasicblock, 32 * widen_factor, layer_blocks, 2 | ||||||
|  |         ) | ||||||
|  |         self.stage_3 = self._make_layer( | ||||||
|  |             WideBasicblock, 64 * widen_factor, layer_blocks, 2 | ||||||
|  |         ) | ||||||
|  |         self.lastact = nn.Sequential( | ||||||
|  |             nn.BatchNorm2d(64 * widen_factor), nn.ReLU(inplace=True) | ||||||
|  |         ) | ||||||
|  |         self.avgpool = nn.AvgPool2d(8) | ||||||
|  |         self.classifier = nn.Linear(64 * widen_factor, num_classes) | ||||||
|  |  | ||||||
|   def get_message(self): |         self.apply(initialize_resnet) | ||||||
|     return self.message |  | ||||||
|  |  | ||||||
|   def _make_layer(self, block, planes, blocks, stride): |     def get_message(self): | ||||||
|  |         return self.message | ||||||
|  |  | ||||||
|     layers = [] |     def _make_layer(self, block, planes, blocks, stride): | ||||||
|     layers.append(block(self.inplanes, planes, stride, self.dropout)) |  | ||||||
|     self.inplanes = planes |  | ||||||
|     for i in range(1, blocks): |  | ||||||
|       layers.append(block(self.inplanes, planes, 1, self.dropout)) |  | ||||||
|  |  | ||||||
|     return nn.Sequential(*layers) |         layers = [] | ||||||
|  |         layers.append(block(self.inplanes, planes, stride, self.dropout)) | ||||||
|  |         self.inplanes = planes | ||||||
|  |         for i in range(1, blocks): | ||||||
|  |             layers.append(block(self.inplanes, planes, 1, self.dropout)) | ||||||
|  |  | ||||||
|   def forward(self, x): |         return nn.Sequential(*layers) | ||||||
|     x = self.conv_3x3(x) |  | ||||||
|     x = self.stage_1(x) |     def forward(self, x): | ||||||
|     x = self.stage_2(x) |         x = self.conv_3x3(x) | ||||||
|     x = self.stage_3(x) |         x = self.stage_1(x) | ||||||
|     x = self.lastact(x) |         x = self.stage_2(x) | ||||||
|     x = self.avgpool(x) |         x = self.stage_3(x) | ||||||
|     features = x.view(x.size(0), -1) |         x = self.lastact(x) | ||||||
|     outs     = self.classifier(features) |         x = self.avgpool(x) | ||||||
|     return features, outs |         features = x.view(x.size(0), -1) | ||||||
|  |         outs = self.classifier(features) | ||||||
|  |         return features, outs | ||||||
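A minimal usage sketch for CifarWideResNet, assuming the module is importable as models.CifarWideResNet; depth=28 and widen_factor=10 give the common WRN-28-10 configuration.

import torch
from models.CifarWideResNet import CifarWideResNet  # assumed import path

net = CifarWideResNet(depth=28, widen_factor=10, num_classes=100, dropout=False)
x = torch.randn(2, 3, 32, 32)
features, logits = net(x)            # (N, 64 * widen_factor) features and class logits
print(features.shape, logits.shape)  # expected: [2, 640] and [2, 100]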
|   | |||||||
| @@ -4,98 +4,114 @@ from .initialization import initialize_resnet | |||||||
|  |  | ||||||
|  |  | ||||||
| class ConvBNReLU(nn.Module): | class ConvBNReLU(nn.Module): | ||||||
|   def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): |     def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): | ||||||
|     super(ConvBNReLU, self).__init__() |         super(ConvBNReLU, self).__init__() | ||||||
|     padding = (kernel_size - 1) // 2 |         padding = (kernel_size - 1) // 2 | ||||||
|     self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False) |         self.conv = nn.Conv2d( | ||||||
|     self.bn   = nn.BatchNorm2d(out_planes) |             in_planes, | ||||||
|     self.relu = nn.ReLU6(inplace=True) |             out_planes, | ||||||
|  |             kernel_size, | ||||||
|  |             stride, | ||||||
|  |             padding, | ||||||
|  |             groups=groups, | ||||||
|  |             bias=False, | ||||||
|  |         ) | ||||||
|  |         self.bn = nn.BatchNorm2d(out_planes) | ||||||
|  |         self.relu = nn.ReLU6(inplace=True) | ||||||
|  |  | ||||||
|   def forward(self, x): |     def forward(self, x): | ||||||
|     out = self.conv( x ) |         out = self.conv(x) | ||||||
|     out = self.bn  ( out ) |         out = self.bn(out) | ||||||
|     out = self.relu( out ) |         out = self.relu(out) | ||||||
|     return out |         return out | ||||||
|  |  | ||||||
|  |  | ||||||
| class InvertedResidual(nn.Module): | class InvertedResidual(nn.Module): | ||||||
|   def __init__(self, inp, oup, stride, expand_ratio): |     def __init__(self, inp, oup, stride, expand_ratio): | ||||||
|     super(InvertedResidual, self).__init__() |         super(InvertedResidual, self).__init__() | ||||||
|     self.stride = stride |         self.stride = stride | ||||||
|     assert stride in [1, 2] |         assert stride in [1, 2] | ||||||
|  |  | ||||||
|     hidden_dim = int(round(inp * expand_ratio)) |         hidden_dim = int(round(inp * expand_ratio)) | ||||||
|     self.use_res_connect = self.stride == 1 and inp == oup |         self.use_res_connect = self.stride == 1 and inp == oup | ||||||
|  |  | ||||||
|     layers = [] |         layers = [] | ||||||
|     if expand_ratio != 1: |         if expand_ratio != 1: | ||||||
|       # pw |             # pw | ||||||
|       layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) |             layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) | ||||||
|     layers.extend([ |         layers.extend( | ||||||
|       # dw |             [ | ||||||
|       ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), |                 # dw | ||||||
|       # pw-linear |                 ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), | ||||||
|       nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), |                 # pw-linear | ||||||
|       nn.BatchNorm2d(oup), |                 nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), | ||||||
|     ]) |                 nn.BatchNorm2d(oup), | ||||||
|     self.conv = nn.Sequential(*layers) |             ] | ||||||
|  |         ) | ||||||
|  |         self.conv = nn.Sequential(*layers) | ||||||
|  |  | ||||||
|   def forward(self, x): |     def forward(self, x): | ||||||
|     if self.use_res_connect: |         if self.use_res_connect: | ||||||
|       return x + self.conv(x) |             return x + self.conv(x) | ||||||
|     else: |         else: | ||||||
|       return self.conv(x) |             return self.conv(x) | ||||||
|  |  | ||||||
|  |  | ||||||
| class MobileNetV2(nn.Module): | class MobileNetV2(nn.Module): | ||||||
|   def __init__(self, num_classes, width_mult, input_channel, last_channel, block_name, dropout): |     def __init__( | ||||||
|     super(MobileNetV2, self).__init__() |         self, num_classes, width_mult, input_channel, last_channel, block_name, dropout | ||||||
|     if block_name == 'InvertedResidual': |     ): | ||||||
|       block = InvertedResidual |         super(MobileNetV2, self).__init__() | ||||||
|     else: |         if block_name == "InvertedResidual": | ||||||
|       raise ValueError('invalid block name : {:}'.format(block_name)) |             block = InvertedResidual | ||||||
|     inverted_residual_setting = [ |         else: | ||||||
|       # t, c,  n, s |             raise ValueError("invalid block name : {:}".format(block_name)) | ||||||
|       [1, 16 , 1, 1], |         inverted_residual_setting = [ | ||||||
|       [6, 24 , 2, 2], |             # t, c,  n, s | ||||||
|       [6, 32 , 3, 2], |             [1, 16, 1, 1], | ||||||
|       [6, 64 , 4, 2], |             [6, 24, 2, 2], | ||||||
|       [6, 96 , 3, 1], |             [6, 32, 3, 2], | ||||||
|       [6, 160, 3, 2], |             [6, 64, 4, 2], | ||||||
|       [6, 320, 1, 1], |             [6, 96, 3, 1], | ||||||
|     ] |             [6, 160, 3, 2], | ||||||
|  |             [6, 320, 1, 1], | ||||||
|  |         ] | ||||||
|  |  | ||||||
|     # building first layer |         # building first layer | ||||||
|     input_channel = int(input_channel * width_mult) |         input_channel = int(input_channel * width_mult) | ||||||
|     self.last_channel = int(last_channel * max(1.0, width_mult)) |         self.last_channel = int(last_channel * max(1.0, width_mult)) | ||||||
|     features = [ConvBNReLU(3, input_channel, stride=2)] |         features = [ConvBNReLU(3, input_channel, stride=2)] | ||||||
|     # building inverted residual blocks |         # building inverted residual blocks | ||||||
|     for t, c, n, s in inverted_residual_setting: |         for t, c, n, s in inverted_residual_setting: | ||||||
|       output_channel = int(c * width_mult) |             output_channel = int(c * width_mult) | ||||||
|       for i in range(n): |             for i in range(n): | ||||||
|         stride = s if i == 0 else 1 |                 stride = s if i == 0 else 1 | ||||||
|         features.append(block(input_channel, output_channel, stride, expand_ratio=t)) |                 features.append( | ||||||
|         input_channel = output_channel |                     block(input_channel, output_channel, stride, expand_ratio=t) | ||||||
|     # building last several layers |                 ) | ||||||
|     features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) |                 input_channel = output_channel | ||||||
|     # make it nn.Sequential |         # building last several layers | ||||||
|     self.features = nn.Sequential(*features) |         features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) | ||||||
|  |         # make it nn.Sequential | ||||||
|  |         self.features = nn.Sequential(*features) | ||||||
|  |  | ||||||
|     # building classifier |         # building classifier | ||||||
|     self.classifier = nn.Sequential( |         self.classifier = nn.Sequential( | ||||||
|       nn.Dropout(dropout), |             nn.Dropout(dropout), | ||||||
|       nn.Linear(self.last_channel, num_classes), |             nn.Linear(self.last_channel, num_classes), | ||||||
|     ) |         ) | ||||||
|     self.message = 'MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}'.format(width_mult, input_channel, last_channel, block_name, dropout) |         self.message = "MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}".format( | ||||||
|  |             width_mult, input_channel, last_channel, block_name, dropout | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     # weight initialization |         # weight initialization | ||||||
|     self.apply( initialize_resnet ) |         self.apply(initialize_resnet) | ||||||
|  |  | ||||||
|   def get_message(self): |     def get_message(self): | ||||||
|     return self.message |         return self.message | ||||||
|  |  | ||||||
|   def forward(self, inputs): |     def forward(self, inputs): | ||||||
|     features = self.features(inputs) |         features = self.features(inputs) | ||||||
|     vectors  = features.mean([2, 3]) |         vectors = features.mean([2, 3]) | ||||||
|     predicts = self.classifier(vectors) |         predicts = self.classifier(vectors) | ||||||
|     return features, predicts |         return features, predicts | ||||||
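A minimal usage sketch for MobileNetV2, assuming the module is importable as models.ImageNet_MobileNetV2; the 32/1280 channel numbers mirror the standard MobileNetV2 configuration and are assumptions here, not values fixed by this file.

import torch
from models.ImageNet_MobileNetV2 import MobileNetV2  # assumed import path

net = MobileNetV2(num_classes=1000, width_mult=1.0, input_channel=32,
                  last_channel=1280, block_name="InvertedResidual", dropout=0.2)
x = torch.randn(1, 3, 224, 224)
features, predicts = net(x)            # conv feature map and class predictions
print(features.shape, predicts.shape)  # expected: [1, 1280, 7, 7] and [1, 1000]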
|   | |||||||
| @@ -2,171 +2,216 @@ | |||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| from .initialization import initialize_resnet | from .initialization import initialize_resnet | ||||||
|  |  | ||||||
|  |  | ||||||
| def conv3x3(in_planes, out_planes, stride=1, groups=1): | def conv3x3(in_planes, out_planes, stride=1, groups=1): | ||||||
|   return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False) |     return nn.Conv2d( | ||||||
|  |         in_planes, | ||||||
|  |         out_planes, | ||||||
|  |         kernel_size=3, | ||||||
|  |         stride=stride, | ||||||
|  |         padding=1, | ||||||
|  |         groups=groups, | ||||||
|  |         bias=False, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def conv1x1(in_planes, out_planes, stride=1): | def conv1x1(in_planes, out_planes, stride=1): | ||||||
|   return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | ||||||
|  |  | ||||||
|  |  | ||||||
| class BasicBlock(nn.Module): | class BasicBlock(nn.Module): | ||||||
|   expansion = 1 |     expansion = 1 | ||||||
|  |  | ||||||
|   def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64): |     def __init__( | ||||||
|     super(BasicBlock, self).__init__() |         self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64 | ||||||
|     if groups != 1 or base_width != 64: |     ): | ||||||
|       raise ValueError('BasicBlock only supports groups=1 and base_width=64') |         super(BasicBlock, self).__init__() | ||||||
|     # Both self.conv1 and self.downsample layers downsample the input when stride != 1 |         if groups != 1 or base_width != 64: | ||||||
|     self.conv1 = conv3x3(inplanes, planes, stride) |             raise ValueError("BasicBlock only supports groups=1 and base_width=64") | ||||||
|     self.bn1   = nn.BatchNorm2d(planes) |         # Both self.conv1 and self.downsample layers downsample the input when stride != 1 | ||||||
|     self.relu  = nn.ReLU(inplace=True) |         self.conv1 = conv3x3(inplanes, planes, stride) | ||||||
|     self.conv2 = conv3x3(planes, planes) |         self.bn1 = nn.BatchNorm2d(planes) | ||||||
|     self.bn2   = nn.BatchNorm2d(planes) |         self.relu = nn.ReLU(inplace=True) | ||||||
|     self.downsample = downsample |         self.conv2 = conv3x3(planes, planes) | ||||||
|     self.stride = stride |         self.bn2 = nn.BatchNorm2d(planes) | ||||||
|  |         self.downsample = downsample | ||||||
|  |         self.stride = stride | ||||||
|  |  | ||||||
|   def forward(self, x): |     def forward(self, x): | ||||||
|     identity = x |         identity = x | ||||||
|  |  | ||||||
|     out = self.conv1(x) |         out = self.conv1(x) | ||||||
|     out = self.bn1(out) |         out = self.bn1(out) | ||||||
|     out = self.relu(out) |         out = self.relu(out) | ||||||
|  |  | ||||||
|     out = self.conv2(out) |         out = self.conv2(out) | ||||||
|     out = self.bn2(out) |         out = self.bn2(out) | ||||||
|  |  | ||||||
|     if self.downsample is not None: |         if self.downsample is not None: | ||||||
|       identity = self.downsample(x) |             identity = self.downsample(x) | ||||||
|  |  | ||||||
|     out += identity |         out += identity | ||||||
|     out = self.relu(out) |         out = self.relu(out) | ||||||
|  |  | ||||||
|     return out |         return out | ||||||
|  |  | ||||||
|  |  | ||||||
| class Bottleneck(nn.Module): | class Bottleneck(nn.Module): | ||||||
|   expansion = 4 |     expansion = 4 | ||||||
|  |  | ||||||
|   def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64): |     def __init__( | ||||||
|     super(Bottleneck, self).__init__() |         self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64 | ||||||
|     width = int(planes * (base_width / 64.)) * groups |     ): | ||||||
|     # Both self.conv2 and self.downsample layers downsample the input when stride != 1 |         super(Bottleneck, self).__init__() | ||||||
|     self.conv1 = conv1x1(inplanes, width) |         width = int(planes * (base_width / 64.0)) * groups | ||||||
|     self.bn1   = nn.BatchNorm2d(width) |         # Both self.conv2 and self.downsample layers downsample the input when stride != 1 | ||||||
|     self.conv2 = conv3x3(width, width, stride, groups) |         self.conv1 = conv1x1(inplanes, width) | ||||||
|     self.bn2   = nn.BatchNorm2d(width) |         self.bn1 = nn.BatchNorm2d(width) | ||||||
|     self.conv3 = conv1x1(width, planes * self.expansion) |         self.conv2 = conv3x3(width, width, stride, groups) | ||||||
|     self.bn3   = nn.BatchNorm2d(planes * self.expansion) |         self.bn2 = nn.BatchNorm2d(width) | ||||||
|     self.relu  = nn.ReLU(inplace=True) |         self.conv3 = conv1x1(width, planes * self.expansion) | ||||||
|     self.downsample = downsample |         self.bn3 = nn.BatchNorm2d(planes * self.expansion) | ||||||
|     self.stride = stride |         self.relu = nn.ReLU(inplace=True) | ||||||
|  |         self.downsample = downsample | ||||||
|  |         self.stride = stride | ||||||
|  |  | ||||||
|   def forward(self, x): |     def forward(self, x): | ||||||
|     identity = x |         identity = x | ||||||
|  |  | ||||||
|     out = self.conv1(x) |         out = self.conv1(x) | ||||||
|     out = self.bn1(out) |         out = self.bn1(out) | ||||||
|     out = self.relu(out) |         out = self.relu(out) | ||||||
|  |  | ||||||
|     out = self.conv2(out) |         out = self.conv2(out) | ||||||
|     out = self.bn2(out) |         out = self.bn2(out) | ||||||
|     out = self.relu(out) |         out = self.relu(out) | ||||||
|  |  | ||||||
|     out = self.conv3(out) |         out = self.conv3(out) | ||||||
|     out = self.bn3(out) |         out = self.bn3(out) | ||||||
|  |  | ||||||
|     if self.downsample is not None: |         if self.downsample is not None: | ||||||
|       identity = self.downsample(x) |             identity = self.downsample(x) | ||||||
|  |  | ||||||
|     out += identity |         out += identity | ||||||
|     out = self.relu(out) |         out = self.relu(out) | ||||||
|  |  | ||||||
|     return out |         return out | ||||||
|  |  | ||||||
|  |  | ||||||
| class ResNet(nn.Module): | class ResNet(nn.Module): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         block_name, | ||||||
|  |         layers, | ||||||
|  |         deep_stem, | ||||||
|  |         num_classes, | ||||||
|  |         zero_init_residual, | ||||||
|  |         groups, | ||||||
|  |         width_per_group, | ||||||
|  |     ): | ||||||
|  |         super(ResNet, self).__init__() | ||||||
|  |  | ||||||
|   def __init__(self, block_name, layers, deep_stem, num_classes, zero_init_residual, groups, width_per_group): |         # planes = [int(width_per_group * groups * 2 ** i) for i in range(4)] | ||||||
|     super(ResNet, self).__init__() |         if block_name == "BasicBlock": | ||||||
|  |             block = BasicBlock | ||||||
|  |         elif block_name == "Bottleneck": | ||||||
|  |             block = Bottleneck | ||||||
|  |         else: | ||||||
|  |             raise ValueError("invalid block-name : {:}".format(block_name)) | ||||||
|  |  | ||||||
|     #planes = [int(width_per_group * groups * 2 ** i) for i in range(4)] |         if not deep_stem: | ||||||
|     if block_name == 'BasicBlock'  : block= BasicBlock |             self.conv = nn.Sequential( | ||||||
|     elif block_name == 'Bottleneck': block= Bottleneck |                 nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False), | ||||||
|     else                           : raise ValueError('invalid block-name : {:}'.format(block_name)) |                 nn.BatchNorm2d(64), | ||||||
|  |                 nn.ReLU(inplace=True), | ||||||
|     if not deep_stem: |             ) | ||||||
|       self.conv = nn.Sequential( |         else: | ||||||
|                    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False), |             self.conv = nn.Sequential( | ||||||
|                    nn.BatchNorm2d(64), nn.ReLU(inplace=True)) |                 nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False), | ||||||
|     else: |                 nn.BatchNorm2d(32), | ||||||
|       self.conv = nn.Sequential( |                 nn.ReLU(inplace=True), | ||||||
|                    nn.Conv2d(           3, 32, kernel_size=3, stride=2, padding=1, bias=False), |                 nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False), | ||||||
|                    nn.BatchNorm2d(32), nn.ReLU(inplace=True), |                 nn.BatchNorm2d(32), | ||||||
|                    nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False), |                 nn.ReLU(inplace=True), | ||||||
|                    nn.BatchNorm2d(32), nn.ReLU(inplace=True), |                 nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False), | ||||||
|                    nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False), |                 nn.BatchNorm2d(64), | ||||||
|                    nn.BatchNorm2d(64), nn.ReLU(inplace=True)) |                 nn.ReLU(inplace=True), | ||||||
|     self.inplanes = 64 |             ) | ||||||
|     self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) |         self.inplanes = 64 | ||||||
|     self.layer1 = self._make_layer(block, 64 , layers[0], stride=1, groups=groups, base_width=width_per_group) |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||||||
|     self.layer2 = self._make_layer(block, 128, layers[1], stride=2, groups=groups, base_width=width_per_group) |         self.layer1 = self._make_layer( | ||||||
|     self.layer3 = self._make_layer(block, 256, layers[2], stride=2, groups=groups, base_width=width_per_group) |             block, 64, layers[0], stride=1, groups=groups, base_width=width_per_group | ||||||
|     self.layer4 = self._make_layer(block, 512, layers[3], stride=2, groups=groups, base_width=width_per_group) |  | ||||||
|     self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) |  | ||||||
|     self.fc      = nn.Linear(512 * block.expansion, num_classes) |  | ||||||
|     self.message = 'block = {:}, layers = {:}, deep_stem = {:}, num_classes = {:}'.format(block, layers, deep_stem, num_classes) |  | ||||||
|  |  | ||||||
|     self.apply( initialize_resnet ) |  | ||||||
|  |  | ||||||
|     # Zero-initialize the last BN in each residual branch, |  | ||||||
|     # so that the residual branch starts with zeros, and each residual block behaves like an identity. |  | ||||||
|     # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 |  | ||||||
|     if zero_init_residual: |  | ||||||
|       for m in self.modules(): |  | ||||||
|         if isinstance(m, Bottleneck): |  | ||||||
|           nn.init.constant_(m.bn3.weight, 0) |  | ||||||
|         elif isinstance(m, BasicBlock): |  | ||||||
|           nn.init.constant_(m.bn2.weight, 0) |  | ||||||
|  |  | ||||||
|   def _make_layer(self, block, planes, blocks, stride, groups, base_width): |  | ||||||
|     downsample = None |  | ||||||
|     if stride != 1 or self.inplanes != planes * block.expansion: |  | ||||||
|       if stride == 2: |  | ||||||
|         downsample = nn.Sequential( |  | ||||||
|           nn.AvgPool2d(kernel_size=2, stride=2, padding=0), |  | ||||||
|           conv1x1(self.inplanes, planes * block.expansion, 1), |  | ||||||
|           nn.BatchNorm2d(planes * block.expansion), |  | ||||||
|         ) |         ) | ||||||
|       elif stride == 1: |         self.layer2 = self._make_layer( | ||||||
|         downsample = nn.Sequential( |             block, 128, layers[1], stride=2, groups=groups, base_width=width_per_group | ||||||
|           conv1x1(self.inplanes, planes * block.expansion, stride), |         ) | ||||||
|           nn.BatchNorm2d(planes * block.expansion), |         self.layer3 = self._make_layer( | ||||||
|  |             block, 256, layers[2], stride=2, groups=groups, base_width=width_per_group | ||||||
|  |         ) | ||||||
|  |         self.layer4 = self._make_layer( | ||||||
|  |             block, 512, layers[3], stride=2, groups=groups, base_width=width_per_group | ||||||
|  |         ) | ||||||
|  |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||||||
|  |         self.fc = nn.Linear(512 * block.expansion, num_classes) | ||||||
|  |         self.message = ( | ||||||
|  |             "block = {:}, layers = {:}, deep_stem = {:}, num_classes = {:}".format( | ||||||
|  |                 block, layers, deep_stem, num_classes | ||||||
|  |             ) | ||||||
|         ) |         ) | ||||||
|       else: raise ValueError('invalid stride [{:}] for downsample'.format(stride)) |  | ||||||
|  |  | ||||||
|     layers = [] |         self.apply(initialize_resnet) | ||||||
|     layers.append(block(self.inplanes, planes, stride, downsample, groups, base_width)) |  | ||||||
|     self.inplanes = planes * block.expansion |  | ||||||
|     for _ in range(1, blocks): |  | ||||||
|       layers.append(block(self.inplanes, planes, 1, None, groups, base_width)) |  | ||||||
|  |  | ||||||
|     return nn.Sequential(*layers) |         # Zero-initialize the last BN in each residual branch, | ||||||
|  |         # so that the residual branch starts with zeros, and each residual block behaves like an identity. | ||||||
|  |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | ||||||
|  |         if zero_init_residual: | ||||||
|  |             for m in self.modules(): | ||||||
|  |                 if isinstance(m, Bottleneck): | ||||||
|  |                     nn.init.constant_(m.bn3.weight, 0) | ||||||
|  |                 elif isinstance(m, BasicBlock): | ||||||
|  |                     nn.init.constant_(m.bn2.weight, 0) | ||||||
|  |  | ||||||
|   def get_message(self): |     def _make_layer(self, block, planes, blocks, stride, groups, base_width): | ||||||
|     return self.message |         downsample = None | ||||||
|  |         if stride != 1 or self.inplanes != planes * block.expansion: | ||||||
|  |             if stride == 2: | ||||||
|  |                 downsample = nn.Sequential( | ||||||
|  |                     nn.AvgPool2d(kernel_size=2, stride=2, padding=0), | ||||||
|  |                     conv1x1(self.inplanes, planes * block.expansion, 1), | ||||||
|  |                     nn.BatchNorm2d(planes * block.expansion), | ||||||
|  |                 ) | ||||||
|  |             elif stride == 1: | ||||||
|  |                 downsample = nn.Sequential( | ||||||
|  |                     conv1x1(self.inplanes, planes * block.expansion, stride), | ||||||
|  |                     nn.BatchNorm2d(planes * block.expansion), | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|  |                 raise ValueError("invalid stride [{:}] for downsample".format(stride)) | ||||||
|  |  | ||||||
|   def forward(self, x): |         layers = [] | ||||||
|     x = self.conv(x) |         layers.append( | ||||||
|     x = self.maxpool(x) |             block(self.inplanes, planes, stride, downsample, groups, base_width) | ||||||
|  |         ) | ||||||
|  |         self.inplanes = planes * block.expansion | ||||||
|  |         for _ in range(1, blocks): | ||||||
|  |             layers.append(block(self.inplanes, planes, 1, None, groups, base_width)) | ||||||
|  |  | ||||||
|     x = self.layer1(x) |         return nn.Sequential(*layers) | ||||||
|     x = self.layer2(x) |  | ||||||
|     x = self.layer3(x) |  | ||||||
|     x = self.layer4(x) |  | ||||||
|  |  | ||||||
|     features = self.avgpool(x) |     def get_message(self): | ||||||
|     features = features.view(features.size(0), -1) |         return self.message | ||||||
|     logits   = self.fc(features) |  | ||||||
|  |  | ||||||
|     return features, logits |     def forward(self, x): | ||||||
|  |         x = self.conv(x) | ||||||
|  |         x = self.maxpool(x) | ||||||
|  |  | ||||||
|  |         x = self.layer1(x) | ||||||
|  |         x = self.layer2(x) | ||||||
|  |         x = self.layer3(x) | ||||||
|  |         x = self.layer4(x) | ||||||
|  |  | ||||||
|  |         features = self.avgpool(x) | ||||||
|  |         features = features.view(features.size(0), -1) | ||||||
|  |         logits = self.fc(features) | ||||||
|  |  | ||||||
|  |         return features, logits | ||||||
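A minimal usage sketch for the ImageNet ResNet, assuming it is importable as models.ImageNet_ResNet; the [3, 4, 6, 3] layer plan with Bottleneck blocks corresponds to ResNet-50.

import torch
from models.ImageNet_ResNet import ResNet  # assumed import path

net = ResNet("Bottleneck", [3, 4, 6, 3], deep_stem=False, num_classes=1000,
             zero_init_residual=True, groups=1, width_per_group=64)
x = torch.randn(1, 3, 224, 224)
features, logits = net(x)
print(features.shape, logits.shape)  # expected: [1, 2048] and [1, 1000]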
|   | |||||||
| @@ -6,29 +6,32 @@ import torch.nn as nn | |||||||
|  |  | ||||||
|  |  | ||||||
| def additive_func(A, B): | def additive_func(A, B): | ||||||
|   assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size()) |     assert A.dim() == B.dim() and A.size(0) == B.size(0), "{:} vs {:}".format( | ||||||
|   C = min(A.size(1), B.size(1)) |         A.size(), B.size() | ||||||
|   if A.size(1) == B.size(1): |     ) | ||||||
|     return A + B |     C = min(A.size(1), B.size(1)) | ||||||
|   elif A.size(1) < B.size(1): |     if A.size(1) == B.size(1): | ||||||
|     out = B.clone() |         return A + B | ||||||
|     out[:,:C] += A |     elif A.size(1) < B.size(1): | ||||||
|     return out |         out = B.clone() | ||||||
|   else: |         out[:, :C] += A | ||||||
|     out = A.clone() |         return out | ||||||
|     out[:,:C] += B |     else: | ||||||
|     return out |         out = A.clone() | ||||||
|  |         out[:, :C] += B | ||||||
|  |         return out | ||||||
|  |  | ||||||
|  |  | ||||||
| def change_key(key, value): | def change_key(key, value): | ||||||
|   def func(m): |     def func(m): | ||||||
|     if hasattr(m, key): |         if hasattr(m, key): | ||||||
|       setattr(m, key, value) |             setattr(m, key, value) | ||||||
|   return func |  | ||||||
|  |     return func | ||||||
|  |  | ||||||
|  |  | ||||||
| def parse_channel_info(xstring): | def parse_channel_info(xstring): | ||||||
|   blocks = xstring.split(' ') |     blocks = xstring.split(" ") | ||||||
|   blocks = [x.split('-') for x in blocks] |     blocks = [x.split("-") for x in blocks] | ||||||
|   blocks = [[int(_) for _ in x] for x in blocks] |     blocks = [[int(_) for _ in x] for x in blocks] | ||||||
|   return blocks |     return blocks | ||||||
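A short worked illustration of these helpers; the models.SharedUtils import path is an assumption about where this utility file lives in the package.

import torch
from models.SharedUtils import additive_func, parse_channel_info  # assumed import path

A = torch.ones(2, 3)          # tensor with fewer channels
B = torch.zeros(2, 5)         # tensor with more channels
out = additive_func(A, B)     # B is cloned and A is added onto its first 3 channels
print(out)                    # every row is [1., 1., 1., 0., 0.]

print(parse_channel_info("8-16 16-32"))  # [[8, 16], [16, 32]]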
|   | |||||||
| @@ -5,10 +5,18 @@ from os import path as osp | |||||||
| from typing import List, Text | from typing import List, Text | ||||||
| import torch | import torch | ||||||
|  |  | ||||||
| __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \ | __all__ = [ | ||||||
|            'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \ |     "change_key", | ||||||
|            'CellStructure', 'CellArchitectures' |     "get_cell_based_tiny_net", | ||||||
|            ] |     "get_search_spaces", | ||||||
|  |     "get_cifar_models", | ||||||
|  |     "get_imagenet_models", | ||||||
|  |     "obtain_model", | ||||||
|  |     "obtain_search_model", | ||||||
|  |     "load_net_from_checkpoint", | ||||||
|  |     "CellStructure", | ||||||
|  |     "CellArchitectures", | ||||||
|  | ] | ||||||
|  |  | ||||||
| # useful modules | # useful modules | ||||||
| from config_utils import dict2config | from config_utils import dict2config | ||||||
| @@ -18,178 +26,301 @@ from models.cell_searchs import CellStructure, CellArchitectures | |||||||
|  |  | ||||||
| # Cell-based NAS Models | # Cell-based NAS Models | ||||||
| def get_cell_based_tiny_net(config): | def get_cell_based_tiny_net(config): | ||||||
|   if isinstance(config, dict): config = dict2config(config, None) # to support the argument being a dict |     if isinstance(config, dict): | ||||||
|   super_type = getattr(config, 'super_type', 'basic') |         config = dict2config(config, None)  # to support the argument being a dict | ||||||
|   group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM', 'generic'] |     super_type = getattr(config, "super_type", "basic") | ||||||
|   if super_type == 'basic' and config.name in group_names: |     group_names = ["DARTS-V1", "DARTS-V2", "GDAS", "SETN", "ENAS", "RANDOM", "generic"] | ||||||
|     from .cell_searchs import nas201_super_nets as nas_super_nets |     if super_type == "basic" and config.name in group_names: | ||||||
|     try: |         from .cell_searchs import nas201_super_nets as nas_super_nets | ||||||
|       return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats) |  | ||||||
|     except: |         try: | ||||||
|       return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space) |             return nas_super_nets[config.name]( | ||||||
|   elif super_type == 'search-shape': |                 config.C, | ||||||
|     from .shape_searchs import GenericNAS301Model |                 config.N, | ||||||
|     genotype = CellStructure.str2structure(config.genotype) |                 config.max_nodes, | ||||||
|     return GenericNAS301Model(config.candidate_Cs, config.max_num_Cs, genotype, config.num_classes, config.affine, config.track_running_stats) |                 config.num_classes, | ||||||
|   elif super_type == 'nasnet-super': |                 config.space, | ||||||
|     from .cell_searchs import nasnet_super_nets as nas_super_nets |                 config.affine, | ||||||
|     return nas_super_nets[config.name](config.C, config.N, config.steps, config.multiplier, \ |                 config.track_running_stats, | ||||||
|                     config.stem_multiplier, config.num_classes, config.space, config.affine, config.track_running_stats) |             ) | ||||||
|   elif config.name == 'infer.tiny': |         except: | ||||||
|     from .cell_infers import TinyNetwork |             return nas_super_nets[config.name]( | ||||||
|     if hasattr(config, 'genotype'): |                 config.C, config.N, config.max_nodes, config.num_classes, config.space | ||||||
|       genotype = config.genotype |             ) | ||||||
|     elif hasattr(config, 'arch_str'): |     elif super_type == "search-shape": | ||||||
|       genotype = CellStructure.str2structure(config.arch_str) |         from .shape_searchs import GenericNAS301Model | ||||||
|     else: raise ValueError('Can not find genotype from this config : {:}'.format(config)) |  | ||||||
|     return TinyNetwork(config.C, config.N, genotype, config.num_classes) |         genotype = CellStructure.str2structure(config.genotype) | ||||||
|   elif config.name == 'infer.shape.tiny': |         return GenericNAS301Model( | ||||||
|     from .shape_infers import DynamicShapeTinyNet |             config.candidate_Cs, | ||||||
|     if isinstance(config.channels, str): |             config.max_num_Cs, | ||||||
|       channels = tuple([int(x) for x in config.channels.split(':')]) |             genotype, | ||||||
|     else: channels = config.channels |             config.num_classes, | ||||||
|     genotype = CellStructure.str2structure(config.genotype) |             config.affine, | ||||||
|     return DynamicShapeTinyNet(channels, genotype, config.num_classes) |             config.track_running_stats, | ||||||
|   elif config.name == 'infer.nasnet-cifar': |         ) | ||||||
|     from .cell_infers import NASNetonCIFAR |     elif super_type == "nasnet-super": | ||||||
|     raise NotImplementedError |         from .cell_searchs import nasnet_super_nets as nas_super_nets | ||||||
|   else: |  | ||||||
|     raise ValueError('invalid network name : {:}'.format(config.name)) |         return nas_super_nets[config.name]( | ||||||
|  |             config.C, | ||||||
|  |             config.N, | ||||||
|  |             config.steps, | ||||||
|  |             config.multiplier, | ||||||
|  |             config.stem_multiplier, | ||||||
|  |             config.num_classes, | ||||||
|  |             config.space, | ||||||
|  |             config.affine, | ||||||
|  |             config.track_running_stats, | ||||||
|  |         ) | ||||||
|  |     elif config.name == "infer.tiny": | ||||||
|  |         from .cell_infers import TinyNetwork | ||||||
|  |  | ||||||
|  |         if hasattr(config, "genotype"): | ||||||
|  |             genotype = config.genotype | ||||||
|  |         elif hasattr(config, "arch_str"): | ||||||
|  |             genotype = CellStructure.str2structure(config.arch_str) | ||||||
|  |         else: | ||||||
|  |             raise ValueError( | ||||||
|  |                 "Can not find genotype from this config : {:}".format(config) | ||||||
|  |             ) | ||||||
|  |         return TinyNetwork(config.C, config.N, genotype, config.num_classes) | ||||||
|  |     elif config.name == "infer.shape.tiny": | ||||||
|  |         from .shape_infers import DynamicShapeTinyNet | ||||||
|  |  | ||||||
|  |         if isinstance(config.channels, str): | ||||||
|  |             channels = tuple([int(x) for x in config.channels.split(":")]) | ||||||
|  |         else: | ||||||
|  |             channels = config.channels | ||||||
|  |         genotype = CellStructure.str2structure(config.genotype) | ||||||
|  |         return DynamicShapeTinyNet(channels, genotype, config.num_classes) | ||||||
|  |     elif config.name == "infer.nasnet-cifar": | ||||||
|  |         from .cell_infers import NASNetonCIFAR | ||||||
|  |  | ||||||
|  |         raise NotImplementedError | ||||||
|  |     else: | ||||||
|  |         raise ValueError("invalid network name : {:}".format(config.name)) | ||||||
|  |  | ||||||
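A minimal usage sketch for this factory, assuming it is exported from the models package; the arch_str below follows the NAS-Bench-201 string format and names a purely illustrative architecture.

from models import get_cell_based_tiny_net  # assumed import path

config = {
    "name": "infer.tiny",
    "C": 16,
    "N": 5,
    "arch_str": "|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_1x1~1|skip_connect~2|",
    "num_classes": 10,
}
net = get_cell_based_tiny_net(config)  # dict configs are converted via dict2config
print(net)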
|  |  | ||||||
| # obtain the search space, i.e., a dict mapping the operation name into a python-function for this op | # obtain the search space, i.e., a dict mapping the operation name into a python-function for this op | ||||||
| def get_search_spaces(xtype, name) -> List[Text]: | def get_search_spaces(xtype, name) -> List[Text]: | ||||||
|   if xtype == 'cell' or xtype == 'tss':  # The topology search space. |     if xtype == "cell" or xtype == "tss":  # The topology search space. | ||||||
|     from .cell_operations import SearchSpaceNames |         from .cell_operations import SearchSpaceNames | ||||||
|     assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys()) |  | ||||||
|     return SearchSpaceNames[name] |         assert name in SearchSpaceNames, "invalid name [{:}] in {:}".format( | ||||||
|   elif xtype == 'sss':  # The size search space. |             name, SearchSpaceNames.keys() | ||||||
|     if name in ['nats-bench', 'nats-bench-size']: |         ) | ||||||
|       return {'candidates': [8, 16, 24, 32, 40, 48, 56, 64], |         return SearchSpaceNames[name] | ||||||
|               'numbers': 5} |     elif xtype == "sss":  # The size search space. | ||||||
|  |         if name in ["nats-bench", "nats-bench-size"]: | ||||||
|  |             return {"candidates": [8, 16, 24, 32, 40, 48, 56, 64], "numbers": 5} | ||||||
|  |         else: | ||||||
|  |             raise ValueError("Invalid name : {:}".format(name)) | ||||||
|     else: |     else: | ||||||
|       raise ValueError('Invalid name : {:}'.format(name)) |         raise ValueError("invalid search-space type is {:}".format(xtype)) | ||||||
|   else: |  | ||||||
|     raise ValueError('invalid search-space type is {:}'.format(xtype)) |  | ||||||
|  |  | ||||||
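A minimal usage sketch for get_search_spaces, assuming it is exported from the models package and that "nats-bench" is a registered key in SearchSpaceNames.

from models import get_search_spaces  # assumed import path

tss_ops = get_search_spaces("tss", "nats-bench")    # list of candidate operations (topology space)
sss_space = get_search_spaces("sss", "nats-bench")  # {"candidates": [8, ..., 64], "numbers": 5}
print(tss_ops, sss_space)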
|  |  | ||||||
| def get_cifar_models(config, extra_path=None): | def get_cifar_models(config, extra_path=None): | ||||||
|   super_type = getattr(config, 'super_type', 'basic') |     super_type = getattr(config, "super_type", "basic") | ||||||
|   if super_type == 'basic': |     if super_type == "basic": | ||||||
|     from .CifarResNet      import CifarResNet |         from .CifarResNet import CifarResNet | ||||||
|     from .CifarDenseNet    import DenseNet |         from .CifarDenseNet import DenseNet | ||||||
|     from .CifarWideResNet  import CifarWideResNet |         from .CifarWideResNet import CifarWideResNet | ||||||
|     if config.arch == 'resnet': |  | ||||||
|       return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual) |         if config.arch == "resnet": | ||||||
|     elif config.arch == 'densenet': |             return CifarResNet( | ||||||
|       return DenseNet(config.growthRate, config.depth, config.reduction, config.class_num, config.bottleneck) |                 config.module, config.depth, config.class_num, config.zero_init_residual | ||||||
|     elif config.arch == 'wideresnet': |             ) | ||||||
|       return CifarWideResNet(config.depth, config.wide_factor, config.class_num, config.dropout) |         elif config.arch == "densenet": | ||||||
|  |             return DenseNet( | ||||||
|  |                 config.growthRate, | ||||||
|  |                 config.depth, | ||||||
|  |                 config.reduction, | ||||||
|  |                 config.class_num, | ||||||
|  |                 config.bottleneck, | ||||||
|  |             ) | ||||||
|  |         elif config.arch == "wideresnet": | ||||||
|  |             return CifarWideResNet( | ||||||
|  |                 config.depth, config.wide_factor, config.class_num, config.dropout | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             raise ValueError("invalid module type : {:}".format(config.arch)) | ||||||
|  |     elif super_type.startswith("infer"): | ||||||
|  |         from .shape_infers import InferWidthCifarResNet | ||||||
|  |         from .shape_infers import InferDepthCifarResNet | ||||||
|  |         from .shape_infers import InferCifarResNet | ||||||
|  |         from .cell_infers import NASNetonCIFAR | ||||||
|  |  | ||||||
|  |         assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format( | ||||||
|  |             super_type | ||||||
|  |         ) | ||||||
|  |         infer_mode = super_type.split("-")[1] | ||||||
|  |         if infer_mode == "width": | ||||||
|  |             return InferWidthCifarResNet( | ||||||
|  |                 config.module, | ||||||
|  |                 config.depth, | ||||||
|  |                 config.xchannels, | ||||||
|  |                 config.class_num, | ||||||
|  |                 config.zero_init_residual, | ||||||
|  |             ) | ||||||
|  |         elif infer_mode == "depth": | ||||||
|  |             return InferDepthCifarResNet( | ||||||
|  |                 config.module, | ||||||
|  |                 config.depth, | ||||||
|  |                 config.xblocks, | ||||||
|  |                 config.class_num, | ||||||
|  |                 config.zero_init_residual, | ||||||
|  |             ) | ||||||
|  |         elif infer_mode == "shape": | ||||||
|  |             return InferCifarResNet( | ||||||
|  |                 config.module, | ||||||
|  |                 config.depth, | ||||||
|  |                 config.xblocks, | ||||||
|  |                 config.xchannels, | ||||||
|  |                 config.class_num, | ||||||
|  |                 config.zero_init_residual, | ||||||
|  |             ) | ||||||
|  |         elif infer_mode == "nasnet.cifar": | ||||||
|  |             genotype = config.genotype | ||||||
|  |             if extra_path is not None:  # reload genotype by extra_path | ||||||
|  |                 if not osp.isfile(extra_path): | ||||||
|  |                     raise ValueError("invalid extra_path : {:}".format(extra_path)) | ||||||
|  |                 xdata = torch.load(extra_path) | ||||||
|  |                 current_epoch = xdata["epoch"] | ||||||
|  |                 genotype = xdata["genotypes"][current_epoch - 1] | ||||||
|  |             C = config.C if hasattr(config, "C") else config.ichannel | ||||||
|  |             N = config.N if hasattr(config, "N") else config.layers | ||||||
|  |             return NASNetonCIFAR( | ||||||
|  |                 C, N, config.stem_multi, config.class_num, genotype, config.auxiliary | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             raise ValueError("invalid infer-mode : {:}".format(infer_mode)) | ||||||
|     else: |     else: | ||||||
|       raise ValueError('invalid module type : {:}'.format(config.arch)) |         raise ValueError("invalid super-type : {:}".format(super_type)) | ||||||
|   elif super_type.startswith('infer'): |  | ||||||
|     from .shape_infers import InferWidthCifarResNet |  | ||||||
|     from .shape_infers import InferDepthCifarResNet |  | ||||||
|     from .shape_infers import InferCifarResNet |  | ||||||
|     from .cell_infers import NASNetonCIFAR |  | ||||||
|     assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type) |  | ||||||
|     infer_mode = super_type.split('-')[1] |  | ||||||
|     if infer_mode == 'width': |  | ||||||
|       return InferWidthCifarResNet(config.module, config.depth, config.xchannels, config.class_num, config.zero_init_residual) |  | ||||||
|     elif infer_mode == 'depth': |  | ||||||
|       return InferDepthCifarResNet(config.module, config.depth, config.xblocks, config.class_num, config.zero_init_residual) |  | ||||||
|     elif infer_mode == 'shape': |  | ||||||
|       return InferCifarResNet(config.module, config.depth, config.xblocks, config.xchannels, config.class_num, config.zero_init_residual) |  | ||||||
|     elif infer_mode == 'nasnet.cifar': |  | ||||||
|       genotype = config.genotype |  | ||||||
|       if extra_path is not None:  # reload genotype by extra_path |  | ||||||
|         if not osp.isfile(extra_path): raise ValueError('invalid extra_path : {:}'.format(extra_path)) |  | ||||||
|         xdata = torch.load(extra_path) |  | ||||||
|         current_epoch = xdata['epoch'] |  | ||||||
|         genotype = xdata['genotypes'][current_epoch-1] |  | ||||||
|       C = config.C if hasattr(config, 'C') else config.ichannel |  | ||||||
|       N = config.N if hasattr(config, 'N') else config.layers |  | ||||||
|       return NASNetonCIFAR(C, N, config.stem_multi, config.class_num, genotype, config.auxiliary) |  | ||||||
|     else: |  | ||||||
|       raise ValueError('invalid infer-mode : {:}'.format(infer_mode)) |  | ||||||
|   else: |  | ||||||
|     raise ValueError('invalid super-type : {:}'.format(super_type)) |  | ||||||
|  |  | ||||||
|  |  | ||||||
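Editor's note (not part of the diff): the "infer" branches above are driven entirely by fields on the config object. A minimal sketch of a config exercising the "infer-depth" case follows; the field names mirror the arguments read above, while every concrete value is an illustrative assumption.

    # Hypothetical config for the "infer-depth" branch of get_cifar_models;
    # the values below are made up for illustration only.
    depth_cfg = dict2config(
        {
            "dataset": "cifar",
            "super_type": "infer-depth",
            "module": "basic",          # assumed block-module name
            "depth": 20,
            "xblocks": [3, 3, 3],       # per-stage block counts kept after search
            "class_num": 10,
            "zero_init_residual": False,
        },
        None,
    )
    net = obtain_model(depth_cfg)  # dispatches through get_cifar_models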
| def get_imagenet_models(config): | def get_imagenet_models(config): | ||||||
|   super_type = getattr(config, 'super_type', 'basic') |     super_type = getattr(config, "super_type", "basic") | ||||||
|   if super_type == 'basic': |     if super_type == "basic": | ||||||
|     from .ImageNet_ResNet import ResNet |         from .ImageNet_ResNet import ResNet | ||||||
|     from .ImageNet_MobileNetV2 import MobileNetV2 |         from .ImageNet_MobileNetV2 import MobileNetV2 | ||||||
|     if config.arch == 'resnet': |  | ||||||
|       return ResNet(config.block_name, config.layers, config.deep_stem, config.class_num, config.zero_init_residual, config.groups, config.width_per_group) |         if config.arch == "resnet": | ||||||
|     elif config.arch == 'mobilenet_v2': |             return ResNet( | ||||||
|       return MobileNetV2(config.class_num, config.width_multi, config.input_channel, config.last_channel, 'InvertedResidual', config.dropout) |                 config.block_name, | ||||||
|  |                 config.layers, | ||||||
|  |                 config.deep_stem, | ||||||
|  |                 config.class_num, | ||||||
|  |                 config.zero_init_residual, | ||||||
|  |                 config.groups, | ||||||
|  |                 config.width_per_group, | ||||||
|  |             ) | ||||||
|  |         elif config.arch == "mobilenet_v2": | ||||||
|  |             return MobileNetV2( | ||||||
|  |                 config.class_num, | ||||||
|  |                 config.width_multi, | ||||||
|  |                 config.input_channel, | ||||||
|  |                 config.last_channel, | ||||||
|  |                 "InvertedResidual", | ||||||
|  |                 config.dropout, | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             raise ValueError("invalid arch : {:}".format(config.arch)) | ||||||
|  |     elif super_type.startswith("infer"):  # NAS searched architecture | ||||||
|  |         assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format( | ||||||
|  |             super_type | ||||||
|  |         ) | ||||||
|  |         infer_mode = super_type.split("-")[1] | ||||||
|  |         if infer_mode == "shape": | ||||||
|  |             from .shape_infers import InferImagenetResNet | ||||||
|  |             from .shape_infers import InferMobileNetV2 | ||||||
|  |  | ||||||
|  |             if config.arch == "resnet": | ||||||
|  |                 return InferImagenetResNet( | ||||||
|  |                     config.block_name, | ||||||
|  |                     config.layers, | ||||||
|  |                     config.xblocks, | ||||||
|  |                     config.xchannels, | ||||||
|  |                     config.deep_stem, | ||||||
|  |                     config.class_num, | ||||||
|  |                     config.zero_init_residual, | ||||||
|  |                 ) | ||||||
|  |             elif config.arch == "MobileNetV2": | ||||||
|  |                 return InferMobileNetV2( | ||||||
|  |                     config.class_num, config.xchannels, config.xblocks, config.dropout | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|  |                 raise ValueError("invalid arch-mode : {:}".format(config.arch)) | ||||||
|  |         else: | ||||||
|  |             raise ValueError("invalid infer-mode : {:}".format(infer_mode)) | ||||||
|     else: |     else: | ||||||
|       raise ValueError('invalid arch : {:}'.format( config.arch )) |         raise ValueError("invalid super-type : {:}".format(super_type)) | ||||||
|   elif super_type.startswith('infer'): # NAS searched architecture |  | ||||||
|     assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type) |  | ||||||
|     infer_mode = super_type.split('-')[1] |  | ||||||
|     if infer_mode == 'shape': |  | ||||||
|       from .shape_infers import InferImagenetResNet |  | ||||||
|       from .shape_infers import InferMobileNetV2 |  | ||||||
|       if config.arch == 'resnet': |  | ||||||
|         return InferImagenetResNet(config.block_name, config.layers, config.xblocks, config.xchannels, config.deep_stem, config.class_num, config.zero_init_residual) |  | ||||||
|       elif config.arch == "MobileNetV2": |  | ||||||
|         return InferMobileNetV2(config.class_num, config.xchannels, config.xblocks, config.dropout) |  | ||||||
|       else: |  | ||||||
|         raise ValueError('invalid arch-mode : {:}'.format(config.arch)) |  | ||||||
|     else: |  | ||||||
|       raise ValueError('invalid infer-mode : {:}'.format(infer_mode)) |  | ||||||
|   else: |  | ||||||
|     raise ValueError('invalid super-type : {:}'.format(super_type)) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # Try to obtain the network by config. | # Try to obtain the network by config. | ||||||
| def obtain_model(config, extra_path=None): | def obtain_model(config, extra_path=None): | ||||||
|   if config.dataset == 'cifar': |     if config.dataset == "cifar": | ||||||
|     return get_cifar_models(config, extra_path) |         return get_cifar_models(config, extra_path) | ||||||
|   elif config.dataset == 'imagenet': |     elif config.dataset == "imagenet": | ||||||
|     return get_imagenet_models(config) |         return get_imagenet_models(config) | ||||||
|   else: |     else: | ||||||
|     raise ValueError('invalid dataset in the model config : {:}'.format(config)) |         raise ValueError("invalid dataset in the model config : {:}".format(config)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def obtain_search_model(config): | def obtain_search_model(config): | ||||||
|   if config.dataset == 'cifar': |     if config.dataset == "cifar": | ||||||
|     if config.arch == 'resnet': |         if config.arch == "resnet": | ||||||
|       from .shape_searchs import SearchWidthCifarResNet |             from .shape_searchs import SearchWidthCifarResNet | ||||||
|       from .shape_searchs import SearchDepthCifarResNet |             from .shape_searchs import SearchDepthCifarResNet | ||||||
|       from .shape_searchs import SearchShapeCifarResNet |             from .shape_searchs import SearchShapeCifarResNet | ||||||
|       if config.search_mode == 'width': |  | ||||||
|         return SearchWidthCifarResNet(config.module, config.depth, config.class_num) |             if config.search_mode == "width": | ||||||
|       elif config.search_mode == 'depth': |                 return SearchWidthCifarResNet( | ||||||
|         return SearchDepthCifarResNet(config.module, config.depth, config.class_num) |                     config.module, config.depth, config.class_num | ||||||
|       elif config.search_mode == 'shape': |                 ) | ||||||
|         return SearchShapeCifarResNet(config.module, config.depth, config.class_num) |             elif config.search_mode == "depth": | ||||||
|       else: raise ValueError('invalid search mode : {:}'.format(config.search_mode)) |                 return SearchDepthCifarResNet( | ||||||
|     elif config.arch == 'simres': |                     config.module, config.depth, config.class_num | ||||||
|       from .shape_searchs import SearchWidthSimResNet |                 ) | ||||||
|       if config.search_mode == 'width': |             elif config.search_mode == "shape": | ||||||
|         return SearchWidthSimResNet(config.depth, config.class_num) |                 return SearchShapeCifarResNet( | ||||||
|       else: raise ValueError('invalid search mode : {:}'.format(config.search_mode)) |                     config.module, config.depth, config.class_num | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|  |                 raise ValueError("invalid search mode : {:}".format(config.search_mode)) | ||||||
|  |         elif config.arch == "simres": | ||||||
|  |             from .shape_searchs import SearchWidthSimResNet | ||||||
|  |  | ||||||
|  |             if config.search_mode == "width": | ||||||
|  |                 return SearchWidthSimResNet(config.depth, config.class_num) | ||||||
|  |             else: | ||||||
|  |                 raise ValueError("invalid search mode : {:}".format(config.search_mode)) | ||||||
|  |         else: | ||||||
|  |             raise ValueError( | ||||||
|  |                 "invalid arch : {:} for dataset [{:}]".format( | ||||||
|  |                     config.arch, config.dataset | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |     elif config.dataset == "imagenet": | ||||||
|  |         from .shape_searchs import SearchShapeImagenetResNet | ||||||
|  |  | ||||||
|  |         assert config.search_mode == "shape", "invalid search-mode : {:}".format( | ||||||
|  |             config.search_mode | ||||||
|  |         ) | ||||||
|  |         if config.arch == "resnet": | ||||||
|  |             return SearchShapeImagenetResNet( | ||||||
|  |                 config.block_name, config.layers, config.deep_stem, config.class_num | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             raise ValueError("invalid model config : {:}".format(config)) | ||||||
|     else: |     else: | ||||||
|       raise ValueError('invalid arch : {:} for dataset [{:}]'.format(config.arch, config.dataset)) |         raise ValueError("invalid dataset in the model config : {:}".format(config)) | ||||||
|   elif config.dataset == 'imagenet': |  | ||||||
|     from .shape_searchs import SearchShapeImagenetResNet |  | ||||||
|     assert config.search_mode == 'shape', 'invalid search-mode : {:}'.format( config.search_mode ) |  | ||||||
|     if config.arch == 'resnet': |  | ||||||
|       return SearchShapeImagenetResNet(config.block_name, config.layers, config.deep_stem, config.class_num) |  | ||||||
|     else: |  | ||||||
|       raise ValueError('invalid model config : {:}'.format(config)) |  | ||||||
|   else: |  | ||||||
|     raise ValueError('invalid dataset in the model config : {:}'.format(config)) |  | ||||||
|  |  | ||||||
|  |  | ||||||
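Editor's note (not part of the diff): obtain_search_model dispatches on dataset, arch, and search_mode in the same config-driven way. A hedged sketch of a config for the CIFAR width-search branch, with all concrete values being assumptions:

    # Hypothetical config for the CIFAR width-search branch above.
    search_cfg = dict2config(
        {
            "dataset": "cifar",
            "arch": "resnet",
            "search_mode": "width",
            "module": "basic",   # assumed block-module name
            "depth": 20,
            "class_num": 10,
        },
        None,
    )
    search_model = obtain_search_model(search_cfg)  # -> SearchWidthCifarResNet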
| def load_net_from_checkpoint(checkpoint): | def load_net_from_checkpoint(checkpoint): | ||||||
|   assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint) |     assert osp.isfile(checkpoint), "checkpoint {:} does not exist".format(checkpoint) | ||||||
|   checkpoint   = torch.load(checkpoint) |     checkpoint = torch.load(checkpoint) | ||||||
|   model_config = dict2config(checkpoint['model-config'], None) |     model_config = dict2config(checkpoint["model-config"], None) | ||||||
|   model        = obtain_model(model_config) |     model = obtain_model(model_config) | ||||||
|   model.load_state_dict(checkpoint['base-model']) |     model.load_state_dict(checkpoint["base-model"]) | ||||||
|   return model |     return model | ||||||
|   | |||||||
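Editor's note (not part of the diff): load_net_from_checkpoint rebuilds the architecture from the stored config before restoring weights, so the checkpoint must carry both pieces. A sketch of the expected round trip, assuming `model` and `model_config_dict` are already defined and with the file path as a hypothetical example:

    # Save side: the keys below are exactly the ones load_net_from_checkpoint reads.
    torch.save(
        {
            "model-config": model_config_dict,   # plain dict accepted by dict2config
            "base-model": model.state_dict(),
        },
        "output/basic-model.pth",                # hypothetical path
    )

    # Load side: obtain_model re-creates the network, then the weights are loaded.
    restored = load_net_from_checkpoint("output/basic-model.pth")
    restored.eval()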
| @@ -21,8 +21,12 @@ def get_model(config: Dict[Text, Any], **kwargs): | |||||||
|         act_cls = super_name2activation[kwargs["act_cls"]] |         act_cls = super_name2activation[kwargs["act_cls"]] | ||||||
|         norm_cls = super_name2norm[kwargs["norm_cls"]] |         norm_cls = super_name2norm[kwargs["norm_cls"]] | ||||||
|         mean, std = kwargs.get("mean", None), kwargs.get("std", None) |         mean, std = kwargs.get("mean", None), kwargs.get("std", None) | ||||||
|         hidden_dim1 = kwargs.get("hidden_dim1", 200) |         if "hidden_dim" in kwargs: | ||||||
|         hidden_dim2 = kwargs.get("hidden_dim2", 100) |             hidden_dim1 = kwargs.get("hidden_dim") | ||||||
|  |             hidden_dim2 = kwargs.get("hidden_dim") | ||||||
|  |         else: | ||||||
|  |             hidden_dim1 = kwargs.get("hidden_dim1", 200) | ||||||
|  |             hidden_dim2 = kwargs.get("hidden_dim2", 100) | ||||||
|         model = SuperSequential( |         model = SuperSequential( | ||||||
|             norm_cls(mean=mean, std=std), |             norm_cls(mean=mean, std=std), | ||||||
|             SuperLinear(kwargs["input_dim"], hidden_dim1), |             SuperLinear(kwargs["input_dim"], hidden_dim1), | ||||||
| @@ -34,4 +38,3 @@ def get_model(config: Dict[Text, Any], **kwargs): | |||||||
|     else: |     else: | ||||||
|         raise TypeError("Unkonwn model type: {:}".format(model_type)) |         raise TypeError("Unkonwn model type: {:}".format(model_type)) | ||||||
|     return model |     return model | ||||||
|  |  | ||||||
|   | |||||||
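Editor's note (not part of the diff): the hunk above lets a single `hidden_dim` value set both hidden widths while preserving the old per-layer defaults of 200 and 100. A hedged sketch of the two call styles, assuming the MLP branch is selected by `model_type="simple_mlp"` and that the remaining keyword names and registered activation/norm names match what this function reads:

    # One shared width for both hidden layers via the new "hidden_dim" key.
    mlp_a = get_model(
        dict(model_type="simple_mlp"),
        input_dim=1,
        output_dim=1,
        act_cls="leaky_relu",
        norm_cls="identity",
        hidden_dim=64,
    )

    # Without "hidden_dim", the old per-layer keys (defaults 200 and 100) apply.
    mlp_b = get_model(
        dict(model_type="simple_mlp"),
        input_dim=1,
        output_dim=1,
        act_cls="leaky_relu",
        norm_cls="identity",
    )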
| @@ -59,6 +59,9 @@ class TensorContainer: | |||||||
|         for tensor in self._tensors: |         for tensor in self._tensors: | ||||||
|             tensor.requires_grad_(requires_grad) |             tensor.requires_grad_(requires_grad) | ||||||
|  |  | ||||||
|  |     def parameters(self): | ||||||
|  |         return self._tensors | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def tensors(self): |     def tensors(self): | ||||||
|         return self._tensors |         return self._tensors | ||||||
|   | |||||||
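Editor's note (not part of the diff): the new parameters() accessor simply exposes the container's underlying tensor list, which lets it be handed straight to a torch optimizer. A minimal sketch, assuming `container` is a TensorContainer whose tensors are leaf tensors with gradients enabled:

    import torch

    # Optimize the container's tensors directly; parameters() returns the raw list.
    optimizer = torch.optim.Adam(container.parameters(), lr=0.01)

    # Dummy quadratic loss over every tensor in the container, for illustration only.
    loss = sum((t ** 2).sum() for t in container.parameters())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()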