Prototype MAML
		| @@ -23,6 +23,9 @@ from datasets.synthetic_core import get_synthetic_env | ||||
| from models.xcore import get_model | ||||
|  | ||||
|  | ||||
| from lfna_utils import lfna_setup | ||||
|  | ||||
|  | ||||
| def subsample(historical_x, historical_y, maxn=10000): | ||||
|     total = historical_x.size(0) | ||||
|     if total <= maxn: | ||||
| @@ -33,24 +36,7 @@ def subsample(historical_x, historical_y, maxn=10000): | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     prepare_seed(args.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     cache_path = ( | ||||
|         logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version) | ||||
|     ).resolve() | ||||
|     if cache_path.exists(): | ||||
|         env_info = torch.load(cache_path) | ||||
|     else: | ||||
|         env_info = dict() | ||||
|         dynamic_env = get_synthetic_env(version=args.env_version) | ||||
|         env_info["total"] = len(dynamic_env) | ||||
|         for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): | ||||
|             env_info["{:}-timestamp".format(idx)] = timestamp | ||||
|             env_info["{:}-x".format(idx)] = _allx | ||||
|             env_info["{:}-y".format(idx)] = _ally | ||||
|         env_info["dynamic_env"] = dynamic_env | ||||
|         torch.save(env_info, cache_path) | ||||
|     logger, env_info = lfna_setup(args) | ||||
|  | ||||
|     # check indexes to be evaluated | ||||
|     to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None) | ||||
| @@ -60,6 +46,8 @@ def main(args): | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     w_container_per_epoch = dict() | ||||
|  | ||||
|     per_timestamp_time, start_time = AverageMeter(), time.time() | ||||
|     for i, idx in enumerate(to_evaluate_indexes): | ||||
|  | ||||
| @@ -89,9 +77,6 @@ def main(args): | ||||
|             output_dim=1, | ||||
|             act_cls="leaky_relu", | ||||
|             norm_cls="identity", | ||||
|             # norm_cls="simple_norm", | ||||
|             # mean=mean, | ||||
|             # std=std, | ||||
|         ) | ||||
|         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||
|         # build optimizer | ||||
| @@ -144,6 +129,7 @@ def main(args): | ||||
|         save_path = logger.path(None) / "{:04d}-{:04d}.pth".format( | ||||
|             idx, env_info["total"] | ||||
|         ) | ||||
|         w_container_per_epoch[idx] = model.get_w_container().no_grad_clone() | ||||
|         save_checkpoint( | ||||
|             { | ||||
|                 "model_state_dict": model.state_dict(), | ||||
| @@ -155,10 +141,14 @@ def main(args): | ||||
|             logger, | ||||
|         ) | ||||
|         logger.log("") | ||||
|  | ||||
|         per_timestamp_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     save_checkpoint( | ||||
|         {"w_container_per_epoch": w_container_per_epoch}, | ||||
|         logger.path(None) / "final-ckp.pth", | ||||
|         logger, | ||||
|     ) | ||||
|     logger.log("-" * 200 + "\n") | ||||
|     logger.close() | ||||
|  | ||||
| @@ -210,5 +200,7 @@ if __name__ == "__main__": | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save dir argument can not be None" | ||||
|     args.save_dir = "{:}-{:}".format(args.save_dir, args.env_version) | ||||
|     args.save_dir = "{:}-{:}-d{:}".format( | ||||
|         args.save_dir, args.env_version, args.hidden_dim | ||||
|     ) | ||||
|     main(args) | ||||
|   | ||||
							
								
								
									
exps/LFNA/basic-maml.py (new file, 220 lines)
							| @@ -0,0 +1,220 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/LFNA/basic-maml.py --env_version v1   # | ||||
| # python exps/LFNA/basic-maml.py --env_version v2   # | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
| from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||
| from log_utils import time_string | ||||
| from log_utils import AverageMeter, convert_secs2time | ||||
|  | ||||
| from utils import split_str2indexes | ||||
|  | ||||
| from procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||
| from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||
| from datasets.synthetic_core import get_synthetic_env | ||||
| from models.xcore import get_model | ||||
| from xlayers import super_core | ||||
|  | ||||
| from lfna_utils import lfna_setup, TimeData | ||||
|  | ||||
|  | ||||
| class MAML: | ||||
|     """A MAML baseline that meta-learns the initialization stored in a weight container.""" | ||||
|  | ||||
|     def __init__(self, container, criterion, meta_lr, inner_lr=0.01, inner_step=1): | ||||
|         self.criterion = criterion | ||||
|         self.container = container | ||||
|         self.meta_optimizer = torch.optim.Adam( | ||||
|             self.container.parameters(), lr=meta_lr, amsgrad=True | ||||
|         ) | ||||
|         self.inner_lr = inner_lr | ||||
|         self.inner_step = inner_step | ||||
|  | ||||
|     def adapt(self, model, dataset): | ||||
|         # take `inner_step` gradient steps on the sampled (past) timestamp | ||||
|         # and return the adapted container; create_graph=True keeps the | ||||
|         # inner-loop graph so the meta-update can backprop through it | ||||
|         fast_container = self.container | ||||
|         for _ in range(self.inner_step): | ||||
|             y_hat = model.forward_with_container(dataset.x, fast_container) | ||||
|             loss = self.criterion(y_hat, dataset.y) | ||||
|             grads = torch.autograd.grad( | ||||
|                 loss, fast_container.parameters(), create_graph=True | ||||
|             ) | ||||
|             fast_container = fast_container.additive( | ||||
|                 [-self.inner_lr * grad for grad in grads] | ||||
|             ) | ||||
|         return fast_container | ||||
|  | ||||
|     def step(self): | ||||
|         torch.nn.utils.clip_grad_norm_(self.container.parameters(), 1.0) | ||||
|         self.meta_optimizer.step() | ||||
|  | ||||
|     def zero_grad(self): | ||||
|         self.meta_optimizer.zero_grad() | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     logger, env_info = lfna_setup(args) | ||||
|  | ||||
|     total_time = env_info["total"] | ||||
|     for i in range(total_time): | ||||
|         for xkey in ("timestamp", "x", "y"): | ||||
|             nkey = "{:}-{:}".format(i, xkey) | ||||
|             assert nkey in env_info, "{:} not in {:}".format(nkey, list(env_info.keys())) | ||||
|     train_time_bar = total_time // 2 | ||||
|     base_model = get_model( | ||||
|         dict(model_type="simple_mlp"), | ||||
|         act_cls="leaky_relu", | ||||
|         norm_cls="identity", | ||||
|         input_dim=1, | ||||
|         output_dim=1, | ||||
|     ) | ||||
|  | ||||
|     w_container = base_model.get_w_container() | ||||
|     criterion = torch.nn.MSELoss() | ||||
|     print("There are {:} weights.".format(w_container.numel())) | ||||
|  | ||||
|     maml = MAML(w_container, criterion, args.meta_lr, args.inner_lr, args.inner_step) | ||||
|  | ||||
|     # meta-training | ||||
|     per_epoch_time, start_time = AverageMeter(), time.time() | ||||
|     for iepoch in range(args.epochs): | ||||
|  | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs) | ||||
|             + need_time | ||||
|         ) | ||||
|  | ||||
|         maml.zero_grad() | ||||
|  | ||||
|         all_meta_losses = [] | ||||
|         for ibatch in range(args.meta_batch): | ||||
|             sampled_timestamp = random.randint(0, train_time_bar) | ||||
|             past_dataset = TimeData( | ||||
|                 sampled_timestamp, | ||||
|                 env_info["{:}-x".format(sampled_timestamp)], | ||||
|                 env_info["{:}-y".format(sampled_timestamp)], | ||||
|             ) | ||||
|             future_dataset = TimeData( | ||||
|                 sampled_timestamp + 1, | ||||
|                 env_info["{:}-x".format(sampled_timestamp + 1)], | ||||
|                 env_info["{:}-y".format(sampled_timestamp + 1)], | ||||
|             ) | ||||
|             fast_container = maml.adapt(base_model, past_dataset) | ||||
|             y_hat = base_model.forward_with_container( | ||||
|                 future_dataset.x, fast_container | ||||
|             ) | ||||
|             future_loss = criterion(y_hat, future_dataset.y) | ||||
|             all_meta_losses.append(future_loss) | ||||
|  | ||||
|         meta_loss = torch.stack(all_meta_losses).mean() | ||||
|         meta_loss.backward() | ||||
|         maml.step() | ||||
|  | ||||
|         logger.log("meta-loss: {:.4f}".format(meta_loss.item())) | ||||
|  | ||||
|         per_epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     logger.log("-" * 200 + "\n") | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Use the data in the past.") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/lfna-synthetic/maml", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--env_version", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         help="The synthetic environment version.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--meta_lr", | ||||
|         type=float, | ||||
|         default=0.01, | ||||
|         help="The learning rate for the meta-optimizer (Adam).", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--inner_lr", | ||||
|         type=float, | ||||
|         default=0.01, | ||||
|         help="The learning rate for the inner optimization", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--inner_step", type=int, default=1, help="The inner loop steps for MAML." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--meta_batch", | ||||
|         type=int, | ||||
|         default=5, | ||||
|         help="The batch size for the meta-model", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--epochs", | ||||
|         type=int, | ||||
|         default=1000, | ||||
|         help="The total number of epochs.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=4, | ||||
|         help="The number of data loading workers (default: 4)", | ||||
|     ) | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save dir argument can not be None" | ||||
|     main(args) | ||||
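
For reference, a minimal, self-contained sketch of the meta-update that basic-maml.py prototypes, written against plain PyTorch (torch.func, PyTorch 2.x) instead of the repo's weight-container API; maml_meta_step and its arguments are illustrative names, not part of this commit:

    import torch
    from torch.func import functional_call

    def maml_meta_step(model, criterion, meta_opt, tasks, inner_lr=0.01, inner_steps=1):
        # tasks: list of ((x_past, y_past), (x_future, y_future)) pairs,
        # mirroring the past/future TimeData pairs sampled in the script.
        meta_opt.zero_grad()
        meta_losses = []
        for (x_s, y_s), (x_q, y_q) in tasks:
            fast = dict(model.named_parameters())
            for _ in range(inner_steps):
                loss = criterion(functional_call(model, fast, (x_s,)), y_s)
                grads = torch.autograd.grad(loss, list(fast.values()), create_graph=True)
                fast = {n: w - inner_lr * g for (n, w), g in zip(fast.items(), grads)}
            # evaluate the adapted weights on the future timestamp
            meta_losses.append(criterion(functional_call(model, fast, (x_q,)), y_q))
        meta_loss = torch.stack(meta_losses).mean()
        meta_loss.backward()  # second-order gradients flow through the inner steps
        meta_opt.step()
        return meta_loss.item()
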
| @@ -1,7 +1,8 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/LFNA/basic-same.py --srange 1-999 | ||||
| # python exps/LFNA/basic-same.py --srange 1-999 --env_version v1 --hidden_dim 16 | ||||
| # python exps/LFNA/basic-same.py --srange 1-999 --env_version v2 --hidden_dim | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| @@ -22,6 +23,8 @@ from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||
| from datasets.synthetic_core import get_synthetic_env | ||||
| from models.xcore import get_model | ||||
|  | ||||
| from lfna_utils import lfna_setup | ||||
|  | ||||
|  | ||||
| def subsample(historical_x, historical_y, maxn=10000): | ||||
|     total = historical_x.size(0) | ||||
| @@ -33,22 +36,7 @@ def subsample(historical_x, historical_y, maxn=10000): | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     prepare_seed(args.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     cache_path = (logger.path(None) / ".." / "env-info.pth").resolve() | ||||
|     if cache_path.exists(): | ||||
|         env_info = torch.load(cache_path) | ||||
|     else: | ||||
|         env_info = dict() | ||||
|         dynamic_env = get_synthetic_env() | ||||
|         env_info["total"] = len(dynamic_env) | ||||
|         for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): | ||||
|             env_info["{:}-timestamp".format(idx)] = timestamp | ||||
|             env_info["{:}-x".format(idx)] = _allx | ||||
|             env_info["{:}-y".format(idx)] = _ally | ||||
|         env_info["dynamic_env"] = dynamic_env | ||||
|         torch.save(env_info, cache_path) | ||||
|     logger, env_info, model_kwargs = lfna_setup(args) | ||||
|  | ||||
|     # check indexes to be evaluated | ||||
|     to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None) | ||||
| @@ -78,16 +66,6 @@ def main(args): | ||||
|         historical_x = env_info["{:}-x".format(idx)] | ||||
|         historical_y = env_info["{:}-y".format(idx)] | ||||
|         # build model | ||||
|         mean, std = historical_x.mean().item(), historical_x.std().item() | ||||
|         model_kwargs = dict( | ||||
|             input_dim=1, | ||||
|             output_dim=1, | ||||
|             act_cls="leaky_relu", | ||||
|             norm_cls="identity", | ||||
|             # norm_cls="simple_norm", | ||||
|             # mean=mean, | ||||
|             # std=std, | ||||
|         ) | ||||
|         model = get_model(dict(model_type="simple_mlp"), **model_kwargs) | ||||
|         # build optimizer | ||||
|         optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True) | ||||
| @@ -151,9 +129,9 @@ def main(args): | ||||
|             logger, | ||||
|         ) | ||||
|         logger.log("") | ||||
|  | ||||
|         per_timestamp_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     save_checkpoint( | ||||
|         {"w_container_per_epoch": w_container_per_epoch}, | ||||
|         logger.path(None) / "final-ckp.pth", | ||||
| @@ -172,6 +150,18 @@ if __name__ == "__main__": | ||||
|         default="./outputs/lfna-synthetic/use-same-timestamp", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--env_version", | ||||
|         type=str, | ||||
|         required=True, | ||||
|         help="The synthetic environment version.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--hidden_dim", | ||||
|         type=int, | ||||
|         required=True, | ||||
|         help="The hidden dimension.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--init_lr", | ||||
|         type=float, | ||||
| @@ -205,4 +195,7 @@ if __name__ == "__main__": | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save dir argument can not be None" | ||||
|     args.save_dir = "{:}-{:}-d{:}".format( | ||||
|         args.save_dir, args.env_version, args.hidden_dim | ||||
|     ) | ||||
|     main(args) | ||||
|   | ||||
							
								
								
									
exps/LFNA/lfna-v0.py (new file, 272 lines)
							| @@ -0,0 +1,272 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| # python exps/LFNA/lfna-v0.py | ||||
| ##################################################### | ||||
| import sys, time, copy, torch, random, argparse | ||||
| from tqdm import tqdm | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
| from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||
| from log_utils import time_string | ||||
| from log_utils import AverageMeter, convert_secs2time | ||||
|  | ||||
| from utils import split_str2indexes | ||||
|  | ||||
| from procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||
| from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||
| from datasets.synthetic_core import get_synthetic_env | ||||
| from models.xcore import get_model | ||||
| from xlayers import super_core | ||||
|  | ||||
|  | ||||
| class LFNAmlp: | ||||
|     """A LFNA meta-model that uses the MLP as delta-net.""" | ||||
|  | ||||
|     def __init__(self, obs_dim, hidden_sizes, act_name): | ||||
|         self.delta_net = super_core.SuperSequential( | ||||
|             super_core.SuperLinear(obs_dim, hidden_sizes[0]), | ||||
|             super_core.super_name2activation[act_name](), | ||||
|             super_core.SuperLinear(hidden_sizes[0], hidden_sizes[1]), | ||||
|             super_core.super_name2activation[act_name](), | ||||
|             super_core.SuperLinear(hidden_sizes[1], 1), | ||||
|         ) | ||||
|         self.meta_optimizer = torch.optim.Adam( | ||||
|             self.delta_net.parameters(), lr=0.01, amsgrad=True | ||||
|         ) | ||||
|  | ||||
|     def adapt(self, model, criterion, w_container, seq_datasets): | ||||
|         w_container.requires_grad_(True) | ||||
|         containers = [w_container] | ||||
|         for idx, dataset in enumerate(seq_datasets): | ||||
|             x, y = dataset.x, dataset.y | ||||
|             y_hat = model.forward_with_container(x, containers[-1]) | ||||
|             loss = criterion(y_hat, y) | ||||
|             gradients = torch.autograd.grad(loss, containers[-1].tensors) | ||||
|             with torch.no_grad(): | ||||
|                 flatten_w = containers[-1].flatten().view(-1, 1) | ||||
|                 flatten_g = containers[-1].flatten(gradients).view(-1, 1) | ||||
|                 input_statistics = torch.tensor([x.mean(), x.std()]).view(1, 2) | ||||
|                 input_statistics = input_statistics.expand(flatten_w.numel(), -1) | ||||
|             delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1) | ||||
|             delta = self.delta_net(delta_inputs).view(-1) | ||||
|             delta = torch.clamp(delta, -0.5, 0.5) | ||||
|             unflatten_delta = containers[-1].unflatten(delta) | ||||
|             future_container = containers[-1].no_grad_clone().additive(unflatten_delta) | ||||
|             # future_container = containers[-1].additive(unflatten_delta) | ||||
|             containers.append(future_container) | ||||
|         # containers = containers[1:] | ||||
|         meta_loss = [] | ||||
|         temp_containers = [] | ||||
|         for idx, dataset in enumerate(seq_datasets): | ||||
|             if idx == 0: | ||||
|                 continue | ||||
|             current_container = containers[idx] | ||||
|             y_hat = model.forward_with_container(dataset.x, current_container) | ||||
|             loss = criterion(y_hat, dataset.y) | ||||
|             meta_loss.append(loss) | ||||
|             temp_containers.append((dataset.timestamp, current_container, -loss.item())) | ||||
|         meta_loss = sum(meta_loss) | ||||
|         w_container.requires_grad_(False) | ||||
|         # meta_loss.backward() | ||||
|         # self.meta_optimizer.step() | ||||
|         return meta_loss, temp_containers | ||||
|  | ||||
|     def step(self): | ||||
|         torch.nn.utils.clip_grad_norm_(self.delta_net.parameters(), 1.0) | ||||
|         self.meta_optimizer.step() | ||||
|  | ||||
|     def zero_grad(self): | ||||
|         self.meta_optimizer.zero_grad() | ||||
|         self.delta_net.zero_grad() | ||||
|  | ||||
|  | ||||
| class TimeData: | ||||
|     def __init__(self, timestamp, xs, ys): | ||||
|         self._timestamp = timestamp | ||||
|         self._xs = xs | ||||
|         self._ys = ys | ||||
|  | ||||
|     @property | ||||
|     def x(self): | ||||
|         return self._xs | ||||
|  | ||||
|     @property | ||||
|     def y(self): | ||||
|         return self._ys | ||||
|  | ||||
|     @property | ||||
|     def timestamp(self): | ||||
|         return self._timestamp | ||||
|  | ||||
|  | ||||
| class Population: | ||||
|     """A population used to maintain models at different timestamps.""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         self._time2model = dict() | ||||
|         self._time2score = dict()  # higher is better | ||||
|  | ||||
|     def append(self, timestamp, model, score): | ||||
|         if timestamp in self._time2model: | ||||
|             if self._time2score[timestamp] > score: | ||||
|                 return | ||||
|         self._time2model[timestamp] = model.no_grad_clone() | ||||
|         self._time2score[timestamp] = score | ||||
|  | ||||
|     def query(self, timestamp): | ||||
|         # return the model whose timestamp is closest to (and not after) the query | ||||
|         closest_timestamp = None | ||||
|         for xtime, model in self._time2model.items(): | ||||
|             if closest_timestamp is None or ( | ||||
|                 xtime <= timestamp | ||||
|                 and timestamp - closest_timestamp >= timestamp - xtime | ||||
|             ): | ||||
|                 closest_timestamp = xtime | ||||
|         return self._time2model[closest_timestamp], closest_timestamp | ||||
|  | ||||
|     def debug_info(self, timestamps): | ||||
|         xstrs = [] | ||||
|         for timestamp in timestamps: | ||||
|             if timestamp in self._time2score: | ||||
|                 xstrs.append( | ||||
|                     "{:04d}: {:.4f}".format(timestamp, self._time2score[timestamp]) | ||||
|                 ) | ||||
|         return ", ".join(xstrs) | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     prepare_seed(args.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     cache_path = (logger.path(None) / ".." / "env-info.pth").resolve() | ||||
|     if cache_path.exists(): | ||||
|         env_info = torch.load(cache_path) | ||||
|     else: | ||||
|         env_info = dict() | ||||
|         dynamic_env = get_synthetic_env() | ||||
|         env_info["total"] = len(dynamic_env) | ||||
|         for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): | ||||
|             env_info["{:}-timestamp".format(idx)] = timestamp | ||||
|             env_info["{:}-x".format(idx)] = _allx | ||||
|             env_info["{:}-y".format(idx)] = _ally | ||||
|         env_info["dynamic_env"] = dynamic_env | ||||
|         torch.save(env_info, cache_path) | ||||
|  | ||||
|     total_time = env_info["total"] | ||||
|     for i in range(total_time): | ||||
|         for xkey in ("timestamp", "x", "y"): | ||||
|             nkey = "{:}-{:}".format(i, xkey) | ||||
|             assert nkey in env_info, "{:} not in {:}".format(nkey, list(env_info.keys())) | ||||
|     train_time_bar = total_time // 2 | ||||
|     base_model = get_model( | ||||
|         dict(model_type="simple_mlp"), | ||||
|         act_cls="leaky_relu", | ||||
|         norm_cls="identity", | ||||
|         input_dim=1, | ||||
|         output_dim=1, | ||||
|     ) | ||||
|  | ||||
|     w_container = base_model.get_w_container() | ||||
|     criterion = torch.nn.MSELoss() | ||||
|     print("There are {:} weights.".format(w_container.numel())) | ||||
|  | ||||
|     adaptor = LFNAmlp(4, (50, 20), "leaky_relu") | ||||
|  | ||||
|     pool = Population() | ||||
|     pool.append(0, w_container, -100) | ||||
|  | ||||
|     # LFNA meta-training | ||||
|     per_epoch_time, start_time = AverageMeter(), time.time() | ||||
|     for iepoch in range(args.epochs): | ||||
|  | ||||
|         need_time = "Time Left: {:}".format( | ||||
|             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True) | ||||
|         ) | ||||
|         logger.log( | ||||
|             "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs) | ||||
|             + need_time | ||||
|         ) | ||||
|  | ||||
|         adaptor.zero_grad() | ||||
|  | ||||
|         debug_timestamp = set() | ||||
|         all_meta_losses = [] | ||||
|         for ibatch in range(args.meta_batch): | ||||
|             sampled_timestamp = random.randint(0, train_time_bar) | ||||
|             query_w_container, query_timestamp = pool.query(sampled_timestamp) | ||||
|             # def adapt(self, model, w_container, xs, ys): | ||||
|             seq_datasets = [] | ||||
|             # xs, ys = [], [] | ||||
|             for it in range(sampled_timestamp, sampled_timestamp + args.max_seq): | ||||
|                 xs = env_info["{:}-x".format(it)] | ||||
|                 ys = env_info["{:}-y".format(it)] | ||||
|                 seq_datasets.append(TimeData(it, xs, ys)) | ||||
|             temp_meta_loss, temp_containers = adaptor.adapt( | ||||
|                 base_model, criterion, query_w_container, seq_datasets | ||||
|             ) | ||||
|             all_meta_losses.append(temp_meta_loss) | ||||
|             for temp_time, temp_container, temp_score in temp_containers: | ||||
|                 pool.append(temp_time, temp_container, temp_score) | ||||
|                 debug_timestamp.add(temp_time) | ||||
|         meta_loss = torch.stack(all_meta_losses).mean() | ||||
|         meta_loss.backward() | ||||
|         adaptor.step() | ||||
|  | ||||
|         debug_str = pool.debug_info(debug_timestamp) | ||||
|         logger.log("meta-loss: {:.4f}".format(meta_loss.item())) | ||||
|  | ||||
|         per_epoch_time.update(time.time() - start_time) | ||||
|         start_time = time.time() | ||||
|  | ||||
|     logger.log("-" * 200 + "\n") | ||||
|     logger.close() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Use the data in the past.") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/lfna-synthetic/lfna-v1", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--init_lr", | ||||
|         type=float, | ||||
|         default=0.1, | ||||
|         help="The initial learning rate for the optimizer (Adam).", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--meta_batch", | ||||
|         type=int, | ||||
|         default=5, | ||||
|         help="The batch size for the meta-model", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--epochs", | ||||
|         type=int, | ||||
|         default=1000, | ||||
|         help="The total number of epochs.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--max_seq", | ||||
|         type=int, | ||||
|         default=5, | ||||
|         help="The maximum length of the sequence.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--workers", | ||||
|         type=int, | ||||
|         default=4, | ||||
|         help="The number of data loading workers (default: 4)", | ||||
|     ) | ||||
|     # Random Seed | ||||
|     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed") | ||||
|     args = parser.parse_args() | ||||
|     if args.rand_seed is None or args.rand_seed < 0: | ||||
|         args.rand_seed = random.randint(1, 100000) | ||||
|     assert args.save_dir is not None, "The save dir argument can not be None" | ||||
|     main(args) | ||||
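
The core of LFNAmlp.adapt above is a per-weight learned update: for every scalar weight it feeds (weight, gradient, input mean, input std) through the 4-input delta-net and applies a clamped additive step. A minimal sketch of that rule on plain tensors (delta_update is an illustrative name; delta_net is any MLP mapping 4 inputs to 1 output):

    import torch

    def delta_update(delta_net, flat_w, flat_g, x):
        # per-weight features: [w, dL/dw, mean(x), std(x)]
        stats = torch.tensor([x.mean(), x.std()]).view(1, 2)
        stats = stats.expand(flat_w.numel(), -1)
        inputs = torch.cat((flat_w.view(-1, 1), flat_g.view(-1, 1), stats), dim=-1)
        # the predicted step is clamped to [-0.5, 0.5], as in adapt()
        delta = torch.clamp(delta_net(inputs).view(-1), -0.5, 0.5)
        return flat_w + delta  # candidate weights for the next timestamp
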
							
								
								
									
exps/LFNA/lfna_utils.py (new file, 61 lines)
							| @@ -0,0 +1,61 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 # | ||||
| ##################################################### | ||||
| import torch | ||||
| from tqdm import tqdm | ||||
| from procedures import prepare_seed, prepare_logger | ||||
| from datasets.synthetic_core import get_synthetic_env | ||||
|  | ||||
|  | ||||
| def lfna_setup(args): | ||||
|     prepare_seed(args.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|     cache_path = ( | ||||
|         logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version) | ||||
|     ).resolve() | ||||
|     if cache_path.exists(): | ||||
|         env_info = torch.load(cache_path) | ||||
|     else: | ||||
|         env_info = dict() | ||||
|         dynamic_env = get_synthetic_env(version=args.env_version) | ||||
|         env_info["total"] = len(dynamic_env) | ||||
|         for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)): | ||||
|             env_info["{:}-timestamp".format(idx)] = timestamp | ||||
|             env_info["{:}-x".format(idx)] = _allx | ||||
|             env_info["{:}-y".format(idx)] = _ally | ||||
|         env_info["dynamic_env"] = dynamic_env | ||||
|         torch.save(env_info, cache_path) | ||||
|  | ||||
|     model_kwargs = dict( | ||||
|         input_dim=1, | ||||
|         output_dim=1, | ||||
|         hidden_dim=args.hidden_dim, | ||||
|         act_cls="leaky_relu", | ||||
|         norm_cls="identity", | ||||
|     ) | ||||
|     return logger, env_info, model_kwargs | ||||
|  | ||||
|  | ||||
| class TimeData: | ||||
|     def __init__(self, timestamp, xs, ys): | ||||
|         self._timestamp = timestamp | ||||
|         self._xs = xs | ||||
|         self._ys = ys | ||||
|  | ||||
|     @property | ||||
|     def x(self): | ||||
|         return self._xs | ||||
|  | ||||
|     @property | ||||
|     def y(self): | ||||
|         return self._ys | ||||
|  | ||||
|     @property | ||||
|     def timestamp(self): | ||||
|         return self._timestamp | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(timestamp={timestamp}, with {num} samples)".format( | ||||
|             name=self.__class__.__name__, timestamp=self._timestamp, num=len(self._xs) | ||||
|         ) | ||||
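
The cache written by lfna_setup is a flat dict: "total" holds the number of timestamps, and each step idx stores its data under "{idx}-timestamp", "{idx}-x", and "{idx}-y". A sketch of reading it back (the file path is illustrative):

    import torch

    env_info = torch.load("env-v1-info.pth")
    for idx in range(env_info["total"]):
        t = env_info["{:}-timestamp".format(idx)]
        x, y = env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
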
| @@ -14,7 +14,9 @@ class Bottleneck(nn.Module): | ||||
|         self.bn1 = nn.BatchNorm2d(nChannels) | ||||
|         self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False) | ||||
|         self.bn2 = nn.BatchNorm2d(interChannels) | ||||
|     self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False) | ||||
|         self.conv2 = nn.Conv2d( | ||||
|             interChannels, growthRate, kernel_size=3, padding=1, bias=False | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         out = self.conv1(F.relu(self.bn1(x))) | ||||
| @@ -27,7 +29,9 @@ class SingleLayer(nn.Module): | ||||
|     def __init__(self, nChannels, growthRate): | ||||
|         super(SingleLayer, self).__init__() | ||||
|         self.bn1 = nn.BatchNorm2d(nChannels) | ||||
|     self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False) | ||||
|         self.conv1 = nn.Conv2d( | ||||
|             nChannels, growthRate, kernel_size=3, padding=1, bias=False | ||||
|         ) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         out = self.conv1(F.relu(self.bn1(x))) | ||||
| @@ -51,10 +55,18 @@ class DenseNet(nn.Module): | ||||
|     def __init__(self, growthRate, depth, reduction, nClasses, bottleneck): | ||||
|         super(DenseNet, self).__init__() | ||||
|  | ||||
|     if bottleneck:  nDenseBlocks = int( (depth-4) / 6 ) | ||||
|     else         :  nDenseBlocks = int( (depth-4) / 3 ) | ||||
|         if bottleneck: | ||||
|             nDenseBlocks = int((depth - 4) / 6) | ||||
|         else: | ||||
|             nDenseBlocks = int((depth - 4) / 3) | ||||
|  | ||||
|     self.message = 'CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}'.format('bottleneck' if bottleneck else 'basic', depth, reduction, growthRate, nClasses) | ||||
|         self.message = "CifarDenseNet : block : {:}, depth : {:}, reduction : {:}, growth-rate = {:}, class = {:}".format( | ||||
|             "bottleneck" if bottleneck else "basic", | ||||
|             depth, | ||||
|             reduction, | ||||
|             growthRate, | ||||
|             nClasses, | ||||
|         ) | ||||
|  | ||||
|         nChannels = 2 * growthRate | ||||
|         self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False) | ||||
| @@ -75,8 +87,8 @@ class DenseNet(nn.Module): | ||||
|         nChannels += nDenseBlocks * growthRate | ||||
|  | ||||
|         self.act = nn.Sequential( | ||||
|                   nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True), | ||||
|                   nn.AvgPool2d(8)) | ||||
|             nn.BatchNorm2d(nChannels), nn.ReLU(inplace=True), nn.AvgPool2d(8) | ||||
|         ) | ||||
|         self.fc = nn.Linear(nChannels, nClasses) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|   | ||||
| @@ -6,10 +6,11 @@ from .SharedUtils    import additive_func | ||||
|  | ||||
|  | ||||
| class Downsample(nn.Module): | ||||
|  | ||||
|     def __init__(self, nIn, nOut, stride): | ||||
|         super(Downsample, self).__init__() | ||||
|     assert stride == 2 and nOut == 2*nIn, 'stride:{} IO:{},{}'.format(stride, nIn, nOut) | ||||
|         assert stride == 2 and nOut == 2 * nIn, "stride:{} IO:{},{}".format( | ||||
|             stride, nIn, nOut | ||||
|         ) | ||||
|         self.in_dim = nIn | ||||
|         self.out_dim = nOut | ||||
|         self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) | ||||
| @@ -22,28 +23,34 @@ class Downsample(nn.Module): | ||||
|  | ||||
|  | ||||
| class ConvBNReLU(nn.Module): | ||||
|    | ||||
|     def __init__(self, nIn, nOut, kernel, stride, padding, bias, relu): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|     self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, bias=bias) | ||||
|         self.conv = nn.Conv2d( | ||||
|             nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, bias=bias | ||||
|         ) | ||||
|         self.bn = nn.BatchNorm2d(nOut) | ||||
|     if relu: self.relu = nn.ReLU(inplace=True) | ||||
|     else   : self.relu = None | ||||
|         if relu: | ||||
|             self.relu = nn.ReLU(inplace=True) | ||||
|         else: | ||||
|             self.relu = None | ||||
|         self.out_dim = nOut | ||||
|         self.num_conv = 1 | ||||
|  | ||||
|     def forward(self, x): | ||||
|         conv = self.conv(x) | ||||
|         bn = self.bn(conv) | ||||
|     if self.relu: return self.relu( bn ) | ||||
|     else        : return bn | ||||
|         if self.relu: | ||||
|             return self.relu(bn) | ||||
|         else: | ||||
|             return bn | ||||
|  | ||||
|  | ||||
| class ResNetBasicblock(nn.Module): | ||||
|     expansion = 1 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBasicblock, self).__init__() | ||||
|     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_a = ConvBNReLU(inplanes, planes, 3, stride, 1, False, True) | ||||
|         self.conv_b = ConvBNReLU(planes, planes, 3, 1, 1, False, False) | ||||
|         if stride == 2: | ||||
| @@ -68,19 +75,23 @@ class ResNetBasicblock(nn.Module): | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
|  | ||||
| class ResNetBottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|  | ||||
|     def __init__(self, inplanes, planes, stride): | ||||
|         super(ResNetBottleneck, self).__init__() | ||||
|     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) | ||||
|         assert stride == 1 or stride == 2, "invalid stride {:}".format(stride) | ||||
|         self.conv_1x1 = ConvBNReLU(inplanes, planes, 1, 1, 0, False, True) | ||||
|         self.conv_3x3 = ConvBNReLU(planes, planes, 3, stride, 1, False, True) | ||||
|     self.conv_1x4 = ConvBNReLU(planes, planes*self.expansion, 1, 1, 0, False, False) | ||||
|         self.conv_1x4 = ConvBNReLU( | ||||
|             planes, planes * self.expansion, 1, 1, 0, False, False | ||||
|         ) | ||||
|         if stride == 2: | ||||
|             self.downsample = Downsample(inplanes, planes * self.expansion, stride) | ||||
|         elif inplanes != planes * self.expansion: | ||||
|       self.downsample = ConvBNReLU(inplanes, planes*self.expansion, 1, 1, 0, False, False) | ||||
|             self.downsample = ConvBNReLU( | ||||
|                 inplanes, planes * self.expansion, 1, 1, 0, False, False | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|         self.out_dim = planes * self.expansion | ||||
| @@ -100,25 +111,25 @@ class ResNetBottleneck(nn.Module): | ||||
|         return F.relu(out, inplace=True) | ||||
|  | ||||
|  | ||||
|  | ||||
| class CifarResNet(nn.Module): | ||||
|  | ||||
|     def __init__(self, block_name, depth, num_classes, zero_init_residual): | ||||
|         super(CifarResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|     if block_name == 'ResNetBasicblock': | ||||
|         if block_name == "ResNetBasicblock": | ||||
|             block = ResNetBasicblock | ||||
|       assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' | ||||
|             assert (depth - 2) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|             layer_blocks = (depth - 2) // 6 | ||||
|     elif block_name == 'ResNetBottleneck': | ||||
|         elif block_name == "ResNetBottleneck": | ||||
|             block = ResNetBottleneck | ||||
|       assert (depth - 2) % 9 == 0, 'depth should be one of 164' | ||||
|             assert (depth - 2) % 9 == 0, "depth should be one of 164" | ||||
|             layer_blocks = (depth - 2) // 9 | ||||
|         else: | ||||
|       raise ValueError('invalid block : {:}'.format(block_name)) | ||||
|             raise ValueError("invalid block : {:}".format(block_name)) | ||||
|  | ||||
|     self.message     = 'CifarResNet : Block : {:}, Depth : {:}, Layers for each block : {:}'.format(block_name, depth, layer_blocks) | ||||
|         self.message = "CifarResNet : Block : {:}, Depth : {:}, Layers for each block : {:}".format( | ||||
|             block_name, depth, layer_blocks | ||||
|         ) | ||||
|         self.num_classes = num_classes | ||||
|         self.channels = [16] | ||||
|         self.layers = nn.ModuleList([ConvBNReLU(3, 16, 3, 1, 1, False, True)]) | ||||
| @@ -130,11 +141,23 @@ class CifarResNet(nn.Module): | ||||
|                 module = block(iC, planes, stride) | ||||
|                 self.channels.append(module.out_dim) | ||||
|                 self.layers.append(module) | ||||
|         self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride) | ||||
|                 self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format( | ||||
|                     stage, | ||||
|                     iL, | ||||
|                     layer_blocks, | ||||
|                     len(self.layers) - 1, | ||||
|                     iC, | ||||
|                     module.out_dim, | ||||
|                     stride, | ||||
|                 ) | ||||
|  | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(module.out_dim, num_classes) | ||||
|     assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) | ||||
|         assert ( | ||||
|             sum(x.num_conv for x in self.layers) + 1 == depth | ||||
|         ), "invalid depth check {:} vs {:}".format( | ||||
|             sum(x.num_conv for x in self.layers) + 1, depth | ||||
|         ) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|         if zero_init_residual: | ||||
|   | ||||
| @@ -9,17 +9,23 @@ class WideBasicblock(nn.Module): | ||||
|         super(WideBasicblock, self).__init__() | ||||
|  | ||||
|         self.bn_a = nn.BatchNorm2d(inplanes) | ||||
|     self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) | ||||
|         self.conv_a = nn.Conv2d( | ||||
|             inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False | ||||
|         ) | ||||
|  | ||||
|         self.bn_b = nn.BatchNorm2d(planes) | ||||
|         if dropout: | ||||
|             self.dropout = nn.Dropout2d(p=0.5, inplace=True) | ||||
|         else: | ||||
|             self.dropout = None | ||||
|     self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) | ||||
|         self.conv_b = nn.Conv2d( | ||||
|             planes, planes, kernel_size=3, stride=1, padding=1, bias=False | ||||
|         ) | ||||
|  | ||||
|         if inplanes != planes: | ||||
|       self.downsample = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False) | ||||
|             self.downsample = nn.Conv2d( | ||||
|                 inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False | ||||
|             ) | ||||
|         else: | ||||
|             self.downsample = None | ||||
|  | ||||
| @@ -46,24 +52,39 @@ class CifarWideResNet(nn.Module): | ||||
|     ResNet optimized for the Cifar dataset, as specified in | ||||
|     https://arxiv.org/abs/1512.03385.pdf | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, depth, widen_factor, num_classes, dropout): | ||||
|         super(CifarWideResNet, self).__init__() | ||||
|  | ||||
|         # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model | ||||
|     assert (depth - 4) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' | ||||
|         assert (depth - 4) % 6 == 0, "depth should be one of 20, 32, 44, 56, 110" | ||||
|         layer_blocks = (depth - 4) // 6 | ||||
|     print ('CifarPreResNet : Depth : {} , Layers for each block : {}'.format(depth, layer_blocks)) | ||||
|         print( | ||||
|             "CifarPreResNet : Depth : {} , Layers for each block : {}".format( | ||||
|                 depth, layer_blocks | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         self.num_classes = num_classes | ||||
|         self.dropout = dropout | ||||
|         self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) | ||||
|  | ||||
|     self.message  = 'Wide ResNet : depth={:}, widen_factor={:}, class={:}'.format(depth, widen_factor, num_classes) | ||||
|         self.message = "Wide ResNet : depth={:}, widen_factor={:}, class={:}".format( | ||||
|             depth, widen_factor, num_classes | ||||
|         ) | ||||
|         self.inplanes = 16 | ||||
|     self.stage_1 = self._make_layer(WideBasicblock, 16*widen_factor, layer_blocks, 1) | ||||
|     self.stage_2 = self._make_layer(WideBasicblock, 32*widen_factor, layer_blocks, 2) | ||||
|     self.stage_3 = self._make_layer(WideBasicblock, 64*widen_factor, layer_blocks, 2) | ||||
|     self.lastact = nn.Sequential(nn.BatchNorm2d(64*widen_factor), nn.ReLU(inplace=True)) | ||||
|         self.stage_1 = self._make_layer( | ||||
|             WideBasicblock, 16 * widen_factor, layer_blocks, 1 | ||||
|         ) | ||||
|         self.stage_2 = self._make_layer( | ||||
|             WideBasicblock, 32 * widen_factor, layer_blocks, 2 | ||||
|         ) | ||||
|         self.stage_3 = self._make_layer( | ||||
|             WideBasicblock, 64 * widen_factor, layer_blocks, 2 | ||||
|         ) | ||||
|         self.lastact = nn.Sequential( | ||||
|             nn.BatchNorm2d(64 * widen_factor), nn.ReLU(inplace=True) | ||||
|         ) | ||||
|         self.avgpool = nn.AvgPool2d(8) | ||||
|         self.classifier = nn.Linear(64 * widen_factor, num_classes) | ||||
|  | ||||
|   | ||||
| @@ -7,7 +7,15 @@ class ConvBNReLU(nn.Module): | ||||
|     def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): | ||||
|         super(ConvBNReLU, self).__init__() | ||||
|         padding = (kernel_size - 1) // 2 | ||||
|     self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False) | ||||
|         self.conv = nn.Conv2d( | ||||
|             in_planes, | ||||
|             out_planes, | ||||
|             kernel_size, | ||||
|             stride, | ||||
|             padding, | ||||
|             groups=groups, | ||||
|             bias=False, | ||||
|         ) | ||||
|         self.bn = nn.BatchNorm2d(out_planes) | ||||
|         self.relu = nn.ReLU6(inplace=True) | ||||
|  | ||||
| @@ -31,13 +39,15 @@ class InvertedResidual(nn.Module): | ||||
|         if expand_ratio != 1: | ||||
|             # pw | ||||
|             layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) | ||||
|     layers.extend([ | ||||
|         layers.extend( | ||||
|             [ | ||||
|                 # dw | ||||
|                 ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), | ||||
|                 # pw-linear | ||||
|                 nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), | ||||
|                 nn.BatchNorm2d(oup), | ||||
|     ]) | ||||
|             ] | ||||
|         ) | ||||
|         self.conv = nn.Sequential(*layers) | ||||
|  | ||||
|     def forward(self, x): | ||||
| @@ -48,12 +58,14 @@ class InvertedResidual(nn.Module): | ||||
|  | ||||
|  | ||||
| class MobileNetV2(nn.Module): | ||||
|   def __init__(self, num_classes, width_mult, input_channel, last_channel, block_name, dropout): | ||||
|     def __init__( | ||||
|         self, num_classes, width_mult, input_channel, last_channel, block_name, dropout | ||||
|     ): | ||||
|         super(MobileNetV2, self).__init__() | ||||
|     if block_name == 'InvertedResidual': | ||||
|         if block_name == "InvertedResidual": | ||||
|             block = InvertedResidual | ||||
|         else: | ||||
|       raise ValueError('invalid block name : {:}'.format(block_name)) | ||||
|             raise ValueError("invalid block name : {:}".format(block_name)) | ||||
|         inverted_residual_setting = [ | ||||
|             # t, c,  n, s | ||||
|             [1, 16, 1, 1], | ||||
| @@ -74,7 +86,9 @@ class MobileNetV2(nn.Module): | ||||
|             output_channel = int(c * width_mult) | ||||
|             for i in range(n): | ||||
|                 stride = s if i == 0 else 1 | ||||
|         features.append(block(input_channel, output_channel, stride, expand_ratio=t)) | ||||
|                 features.append( | ||||
|                     block(input_channel, output_channel, stride, expand_ratio=t) | ||||
|                 ) | ||||
|                 input_channel = output_channel | ||||
|         # building last several layers | ||||
|         features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) | ||||
| @@ -86,7 +100,9 @@ class MobileNetV2(nn.Module): | ||||
|             nn.Dropout(dropout), | ||||
|             nn.Linear(self.last_channel, num_classes), | ||||
|         ) | ||||
|     self.message = 'MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}'.format(width_mult, input_channel, last_channel, block_name, dropout) | ||||
|         self.message = "MobileNetV2 : width_mult={:}, in-C={:}, last-C={:}, block={:}, dropout={:}".format( | ||||
|             width_mult, input_channel, last_channel, block_name, dropout | ||||
|         ) | ||||
|  | ||||
|         # weight initialization | ||||
|         self.apply(initialize_resnet) | ||||
|   | ||||
| @@ -2,8 +2,17 @@ | ||||
| import torch.nn as nn | ||||
| from .initialization import initialize_resnet | ||||
|  | ||||
|  | ||||
| def conv3x3(in_planes, out_planes, stride=1, groups=1): | ||||
|   return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False) | ||||
|     return nn.Conv2d( | ||||
|         in_planes, | ||||
|         out_planes, | ||||
|         kernel_size=3, | ||||
|         stride=stride, | ||||
|         padding=1, | ||||
|         groups=groups, | ||||
|         bias=False, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def conv1x1(in_planes, out_planes, stride=1): | ||||
| @@ -13,10 +22,12 @@ def conv1x1(in_planes, out_planes, stride=1): | ||||
| class BasicBlock(nn.Module): | ||||
|     expansion = 1 | ||||
|  | ||||
|   def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64): | ||||
|     def __init__( | ||||
|         self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64 | ||||
|     ): | ||||
|         super(BasicBlock, self).__init__() | ||||
|         if groups != 1 or base_width != 64: | ||||
|       raise ValueError('BasicBlock only supports groups=1 and base_width=64') | ||||
|             raise ValueError("BasicBlock only supports groups=1 and base_width=64") | ||||
|         # Both self.conv1 and self.downsample layers downsample the input when stride != 1 | ||||
|         self.conv1 = conv3x3(inplanes, planes, stride) | ||||
|         self.bn1 = nn.BatchNorm2d(planes) | ||||
| @@ -48,9 +59,11 @@ class BasicBlock(nn.Module): | ||||
| class Bottleneck(nn.Module): | ||||
|     expansion = 4 | ||||
|  | ||||
|   def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64): | ||||
|     def __init__( | ||||
|         self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64 | ||||
|     ): | ||||
|         super(Bottleneck, self).__init__() | ||||
|     width = int(planes * (base_width / 64.)) * groups | ||||
|         width = int(planes * (base_width / 64.0)) * groups | ||||
|         # Both self.conv2 and self.downsample layers downsample the input when stride != 1 | ||||
|         self.conv1 = conv1x1(inplanes, width) | ||||
|         self.bn1 = nn.BatchNorm2d(width) | ||||
| @@ -86,36 +99,65 @@ class Bottleneck(nn.Module): | ||||
|  | ||||
|  | ||||
| class ResNet(nn.Module): | ||||
|  | ||||
|   def __init__(self, block_name, layers, deep_stem, num_classes, zero_init_residual, groups, width_per_group): | ||||
|     def __init__( | ||||
|         self, | ||||
|         block_name, | ||||
|         layers, | ||||
|         deep_stem, | ||||
|         num_classes, | ||||
|         zero_init_residual, | ||||
|         groups, | ||||
|         width_per_group, | ||||
|     ): | ||||
|         super(ResNet, self).__init__() | ||||
|  | ||||
|         # planes = [int(width_per_group * groups * 2 ** i) for i in range(4)] | ||||
|     if block_name == 'BasicBlock'  : block= BasicBlock | ||||
|     elif block_name == 'Bottleneck': block= Bottleneck | ||||
|     else                           : raise ValueError('invalid block-name : {:}'.format(block_name)) | ||||
|         if block_name == "BasicBlock": | ||||
|             block = BasicBlock | ||||
|         elif block_name == "Bottleneck": | ||||
|             block = Bottleneck | ||||
|         else: | ||||
|             raise ValueError("invalid block-name : {:}".format(block_name)) | ||||
|  | ||||
|         if not deep_stem: | ||||
|             self.conv = nn.Sequential( | ||||
|                 nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False), | ||||
|                    nn.BatchNorm2d(64), nn.ReLU(inplace=True)) | ||||
|                 nn.BatchNorm2d(64), | ||||
|                 nn.ReLU(inplace=True), | ||||
|             ) | ||||
|         else: | ||||
|             self.conv = nn.Sequential( | ||||
|                 nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False), | ||||
|                    nn.BatchNorm2d(32), nn.ReLU(inplace=True), | ||||
|                 nn.BatchNorm2d(32), | ||||
|                 nn.ReLU(inplace=True), | ||||
|                 nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False), | ||||
|                    nn.BatchNorm2d(32), nn.ReLU(inplace=True), | ||||
|                 nn.BatchNorm2d(32), | ||||
|                 nn.ReLU(inplace=True), | ||||
|                 nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False), | ||||
|                    nn.BatchNorm2d(64), nn.ReLU(inplace=True)) | ||||
|                 nn.BatchNorm2d(64), | ||||
|                 nn.ReLU(inplace=True), | ||||
|             ) | ||||
|         self.inplanes = 64 | ||||
|         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||||
|     self.layer1 = self._make_layer(block, 64 , layers[0], stride=1, groups=groups, base_width=width_per_group) | ||||
|     self.layer2 = self._make_layer(block, 128, layers[1], stride=2, groups=groups, base_width=width_per_group) | ||||
|     self.layer3 = self._make_layer(block, 256, layers[2], stride=2, groups=groups, base_width=width_per_group) | ||||
|     self.layer4 = self._make_layer(block, 512, layers[3], stride=2, groups=groups, base_width=width_per_group) | ||||
|         self.layer1 = self._make_layer( | ||||
|             block, 64, layers[0], stride=1, groups=groups, base_width=width_per_group | ||||
|         ) | ||||
|         self.layer2 = self._make_layer( | ||||
|             block, 128, layers[1], stride=2, groups=groups, base_width=width_per_group | ||||
|         ) | ||||
|         self.layer3 = self._make_layer( | ||||
|             block, 256, layers[2], stride=2, groups=groups, base_width=width_per_group | ||||
|         ) | ||||
|         self.layer4 = self._make_layer( | ||||
|             block, 512, layers[3], stride=2, groups=groups, base_width=width_per_group | ||||
|         ) | ||||
|         self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||||
|         self.fc = nn.Linear(512 * block.expansion, num_classes) | ||||
|     self.message = 'block = {:}, layers = {:}, deep_stem = {:}, num_classes = {:}'.format(block, layers, deep_stem, num_classes) | ||||
|         self.message = ( | ||||
|             "block = {:}, layers = {:}, deep_stem = {:}, num_classes = {:}".format( | ||||
|                 block, layers, deep_stem, num_classes | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         self.apply(initialize_resnet) | ||||
|  | ||||
| @@ -143,10 +185,13 @@ class ResNet(nn.Module): | ||||
|                     conv1x1(self.inplanes, planes * block.expansion, stride), | ||||
|                     nn.BatchNorm2d(planes * block.expansion), | ||||
|                 ) | ||||
|       else: raise ValueError('invalid stride [{:}] for downsample'.format(stride)) | ||||
|             else: | ||||
|                 raise ValueError("invalid stride [{:}] for downsample".format(stride)) | ||||
|  | ||||
|         layers = [] | ||||
|     layers.append(block(self.inplanes, planes, stride, downsample, groups, base_width)) | ||||
|         layers.append( | ||||
|             block(self.inplanes, planes, stride, downsample, groups, base_width) | ||||
|         ) | ||||
|         self.inplanes = planes * block.expansion | ||||
|         for _ in range(1, blocks): | ||||
|             layers.append(block(self.inplanes, planes, 1, None, groups, base_width)) | ||||
|   | ||||
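|  | ||||
| The `width` computation above is what lets a single Bottleneck class cover both plain ResNets and ResNeXt-style variants. A minimal sketch of the arithmetic (the values are chosen for illustration, not taken from this repo's configs): | ||||
|  | ||||
| # width = int(planes * (base_width / 64.0)) * groups | ||||
| # ResNet-50 style: base_width=64, groups=1 -> width == planes | ||||
| assert int(64 * (64 / 64.0)) * 1 == 64 | ||||
| # ResNeXt-50 (32x4d) style: base_width=4, groups=32 -> width == 2 * planes | ||||
| assert int(64 * (4 / 64.0)) * 32 == 128 | ||||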
| @@ -6,7 +6,9 @@ import torch.nn as nn | ||||
|  | ||||
|  | ||||
| def additive_func(A, B): | ||||
|   assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size()) | ||||
|     assert A.dim() == B.dim() and A.size(0) == B.size(0), "{:} vs {:}".format( | ||||
|         A.size(), B.size() | ||||
|     ) | ||||
|     C = min(A.size(1), B.size(1)) | ||||
|     if A.size(1) == B.size(1): | ||||
|         return A + B | ||||
| @@ -24,11 +26,12 @@ def change_key(key, value): | ||||
|     def func(m): | ||||
|         if hasattr(m, key): | ||||
|             setattr(m, key, value) | ||||
|  | ||||
|     return func | ||||
|  | ||||
|  | ||||
| def parse_channel_info(xstring): | ||||
|   blocks = xstring.split(' ') | ||||
|   blocks = [x.split('-') for x in blocks] | ||||
|     blocks = xstring.split(" ") | ||||
|     blocks = [x.split("-") for x in blocks] | ||||
|     blocks = [[int(_) for _ in x] for x in blocks] | ||||
|     return blocks | ||||
|   | ||||
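|  | ||||
| A quick sketch of how these helpers behave (the inputs are made up for illustration): `additive_func` is a plain elementwise sum when the channel counts match, `change_key` returns a closure suitable for `nn.Module.apply` (it sets the named attribute on every sub-module that has it), and `parse_channel_info` turns a space-separated string of dash-joined integers into nested int lists. | ||||
|  | ||||
| import torch | ||||
|  | ||||
| a, b = torch.randn(2, 8), torch.randn(2, 8) | ||||
| assert torch.equal(additive_func(a, b), a + b)  # equal channel counts: plain sum | ||||
|  | ||||
| assert parse_channel_info("3-16 16-32 32-64") == [[3, 16], [16, 32], [32, 64]] | ||||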
| @@ -5,9 +5,17 @@ from os import path as osp | ||||
| from typing import List, Text | ||||
| import torch | ||||
|  | ||||
| __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \ | ||||
|            'obtain_model', 'obtain_search_model', 'load_net_from_checkpoint', \ | ||||
|            'CellStructure', 'CellArchitectures' | ||||
| __all__ = [ | ||||
|     "change_key", | ||||
|     "get_cell_based_tiny_net", | ||||
|     "get_search_spaces", | ||||
|     "get_cifar_models", | ||||
|     "get_imagenet_models", | ||||
|     "obtain_model", | ||||
|     "obtain_search_model", | ||||
|     "load_net_from_checkpoint", | ||||
|     "CellStructure", | ||||
|     "CellArchitectures", | ||||
| ] | ||||
|  | ||||
| # useful modules | ||||
| @@ -18,178 +26,301 @@ from models.cell_searchs import CellStructure, CellArchitectures | ||||
|  | ||||
| # Cell-based NAS Models | ||||
| def get_cell_based_tiny_net(config): | ||||
|   if isinstance(config, dict): config = dict2config(config, None) # to support the argument being a dict | ||||
|   super_type = getattr(config, 'super_type', 'basic') | ||||
|   group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM', 'generic'] | ||||
|   if super_type == 'basic' and config.name in group_names: | ||||
|     if isinstance(config, dict): | ||||
|         config = dict2config(config, None)  # to support the argument being a dict | ||||
|     super_type = getattr(config, "super_type", "basic") | ||||
|     group_names = ["DARTS-V1", "DARTS-V2", "GDAS", "SETN", "ENAS", "RANDOM", "generic"] | ||||
|     if super_type == "basic" and config.name in group_names: | ||||
|         from .cell_searchs import nas201_super_nets as nas_super_nets | ||||
|  | ||||
|         try: | ||||
|       return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space, config.affine, config.track_running_stats) | ||||
|             return nas_super_nets[config.name]( | ||||
|                 config.C, | ||||
|                 config.N, | ||||
|                 config.max_nodes, | ||||
|                 config.num_classes, | ||||
|                 config.space, | ||||
|                 config.affine, | ||||
|                 config.track_running_stats, | ||||
|             ) | ||||
|         except (TypeError, AttributeError):  # fall back to the shorter signature for older super-nets | ||||
|       return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space) | ||||
|   elif super_type == 'search-shape': | ||||
|             return nas_super_nets[config.name]( | ||||
|                 config.C, config.N, config.max_nodes, config.num_classes, config.space | ||||
|             ) | ||||
|     elif super_type == "search-shape": | ||||
|         from .shape_searchs import GenericNAS301Model | ||||
|  | ||||
|         genotype = CellStructure.str2structure(config.genotype) | ||||
|     return GenericNAS301Model(config.candidate_Cs, config.max_num_Cs, genotype, config.num_classes, config.affine, config.track_running_stats) | ||||
|   elif super_type == 'nasnet-super': | ||||
|         return GenericNAS301Model( | ||||
|             config.candidate_Cs, | ||||
|             config.max_num_Cs, | ||||
|             genotype, | ||||
|             config.num_classes, | ||||
|             config.affine, | ||||
|             config.track_running_stats, | ||||
|         ) | ||||
|     elif super_type == "nasnet-super": | ||||
|         from .cell_searchs import nasnet_super_nets as nas_super_nets | ||||
|     return nas_super_nets[config.name](config.C, config.N, config.steps, config.multiplier, \ | ||||
|                     config.stem_multiplier, config.num_classes, config.space, config.affine, config.track_running_stats) | ||||
|   elif config.name == 'infer.tiny': | ||||
|  | ||||
|         return nas_super_nets[config.name]( | ||||
|             config.C, | ||||
|             config.N, | ||||
|             config.steps, | ||||
|             config.multiplier, | ||||
|             config.stem_multiplier, | ||||
|             config.num_classes, | ||||
|             config.space, | ||||
|             config.affine, | ||||
|             config.track_running_stats, | ||||
|         ) | ||||
|     elif config.name == "infer.tiny": | ||||
|         from .cell_infers import TinyNetwork | ||||
|     if hasattr(config, 'genotype'): | ||||
|  | ||||
|         if hasattr(config, "genotype"): | ||||
|             genotype = config.genotype | ||||
|     elif hasattr(config, 'arch_str'): | ||||
|         elif hasattr(config, "arch_str"): | ||||
|             genotype = CellStructure.str2structure(config.arch_str) | ||||
|     else: raise ValueError('Can not find genotype from this config : {:}'.format(config)) | ||||
|         else: | ||||
|             raise ValueError( | ||||
|                 "Can not find genotype from this config : {:}".format(config) | ||||
|             ) | ||||
|         return TinyNetwork(config.C, config.N, genotype, config.num_classes) | ||||
|   elif config.name == 'infer.shape.tiny': | ||||
|     elif config.name == "infer.shape.tiny": | ||||
|         from .shape_infers import DynamicShapeTinyNet | ||||
|  | ||||
|         if isinstance(config.channels, str): | ||||
|       channels = tuple([int(x) for x in config.channels.split(':')]) | ||||
|     else: channels = config.channels | ||||
|             channels = tuple([int(x) for x in config.channels.split(":")]) | ||||
|         else: | ||||
|             channels = config.channels | ||||
|         genotype = CellStructure.str2structure(config.genotype) | ||||
|         return DynamicShapeTinyNet(channels, genotype, config.num_classes) | ||||
|   elif config.name == 'infer.nasnet-cifar': | ||||
|     elif config.name == "infer.nasnet-cifar": | ||||
|         from .cell_infers import NASNetonCIFAR | ||||
|  | ||||
|         raise NotImplementedError | ||||
|     else: | ||||
|     raise ValueError('invalid network name : {:}'.format(config.name)) | ||||
|         raise ValueError("invalid network name : {:}".format(config.name)) | ||||
|  | ||||
|  | ||||
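| For reference, a hedged usage sketch of the dispatch above, taking the `infer.tiny` branch; the field values are illustrative, and the architecture string follows the NAS-Bench-201 `|op~input|` convention: | ||||
|  | ||||
| config = dict( | ||||
|     name="infer.tiny", | ||||
|     C=16, | ||||
|     N=5, | ||||
|     num_classes=10, | ||||
|     arch_str="|nor_conv_3x3~0|+|nor_conv_3x3~0|skip_connect~1|+|skip_connect~0|skip_connect~1|nor_conv_3x3~2|", | ||||
| ) | ||||
| net = get_cell_based_tiny_net(config)  # the dict is converted via dict2config | ||||
|  | ||||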
| # Obtain the search space: a list of candidate operation names for the topology ('cell'/'tss') space, or a dict of candidate channel sizes for the size ('sss') space. | ||||
| def get_search_spaces(xtype, name) -> List[Text]: | ||||
|   if xtype == 'cell' or xtype == 'tss':  # The topology search space. | ||||
|     if xtype == "cell" or xtype == "tss":  # The topology search space. | ||||
|         from .cell_operations import SearchSpaceNames | ||||
|     assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys()) | ||||
|  | ||||
|         assert name in SearchSpaceNames, "invalid name [{:}] in {:}".format( | ||||
|             name, SearchSpaceNames.keys() | ||||
|         ) | ||||
|         return SearchSpaceNames[name] | ||||
|   elif xtype == 'sss':  # The size search space. | ||||
|     if name in ['nats-bench', 'nats-bench-size']: | ||||
|       return {'candidates': [8, 16, 24, 32, 40, 48, 56, 64], | ||||
|               'numbers': 5} | ||||
|     elif xtype == "sss":  # The size search space. | ||||
|         if name in ["nats-bench", "nats-bench-size"]: | ||||
|             return {"candidates": [8, 16, 24, 32, 40, 48, 56, 64], "numbers": 5} | ||||
|         else: | ||||
|       raise ValueError('Invalid name : {:}'.format(name)) | ||||
|             raise ValueError("Invalid name : {:}".format(name)) | ||||
|     else: | ||||
|     raise ValueError('invalid search-space type is {:}'.format(xtype)) | ||||
|         raise ValueError("invalid search-space type is {:}".format(xtype)) | ||||
|  | ||||
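| Note the asymmetric return types above: the topology space yields a list of operation names, while the size space yields a dict. For example (the topology name is assumed to be a key of SearchSpaceNames): | ||||
|  | ||||
| ops = get_search_spaces("tss", "nas-bench-201") | ||||
| sss = get_search_spaces("sss", "nats-bench") | ||||
| # sss == {"candidates": [8, 16, 24, 32, 40, 48, 56, 64], "numbers": 5} | ||||
|  | ||||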
|  | ||||
| def get_cifar_models(config, extra_path=None): | ||||
|   super_type = getattr(config, 'super_type', 'basic') | ||||
|   if super_type == 'basic': | ||||
|     super_type = getattr(config, "super_type", "basic") | ||||
|     if super_type == "basic": | ||||
|         from .CifarResNet import CifarResNet | ||||
|         from .CifarDenseNet import DenseNet | ||||
|         from .CifarWideResNet import CifarWideResNet | ||||
|     if config.arch == 'resnet': | ||||
|       return CifarResNet(config.module, config.depth, config.class_num, config.zero_init_residual) | ||||
|     elif config.arch == 'densenet': | ||||
|       return DenseNet(config.growthRate, config.depth, config.reduction, config.class_num, config.bottleneck) | ||||
|     elif config.arch == 'wideresnet': | ||||
|       return CifarWideResNet(config.depth, config.wide_factor, config.class_num, config.dropout) | ||||
|  | ||||
|         if config.arch == "resnet": | ||||
|             return CifarResNet( | ||||
|                 config.module, config.depth, config.class_num, config.zero_init_residual | ||||
|             ) | ||||
|         elif config.arch == "densenet": | ||||
|             return DenseNet( | ||||
|                 config.growthRate, | ||||
|                 config.depth, | ||||
|                 config.reduction, | ||||
|                 config.class_num, | ||||
|                 config.bottleneck, | ||||
|             ) | ||||
|         elif config.arch == "wideresnet": | ||||
|             return CifarWideResNet( | ||||
|                 config.depth, config.wide_factor, config.class_num, config.dropout | ||||
|             ) | ||||
|         else: | ||||
|       raise ValueError('invalid module type : {:}'.format(config.arch)) | ||||
|   elif super_type.startswith('infer'): | ||||
|             raise ValueError("invalid module type : {:}".format(config.arch)) | ||||
|     elif super_type.startswith("infer"): | ||||
|         from .shape_infers import InferWidthCifarResNet | ||||
|         from .shape_infers import InferDepthCifarResNet | ||||
|         from .shape_infers import InferCifarResNet | ||||
|         from .cell_infers import NASNetonCIFAR | ||||
|     assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type) | ||||
|     infer_mode = super_type.split('-')[1] | ||||
|     if infer_mode == 'width': | ||||
|       return InferWidthCifarResNet(config.module, config.depth, config.xchannels, config.class_num, config.zero_init_residual) | ||||
|     elif infer_mode == 'depth': | ||||
|       return InferDepthCifarResNet(config.module, config.depth, config.xblocks, config.class_num, config.zero_init_residual) | ||||
|     elif infer_mode == 'shape': | ||||
|       return InferCifarResNet(config.module, config.depth, config.xblocks, config.xchannels, config.class_num, config.zero_init_residual) | ||||
|     elif infer_mode == 'nasnet.cifar': | ||||
|  | ||||
|         assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format( | ||||
|             super_type | ||||
|         ) | ||||
|         infer_mode = super_type.split("-")[1] | ||||
|         if infer_mode == "width": | ||||
|             return InferWidthCifarResNet( | ||||
|                 config.module, | ||||
|                 config.depth, | ||||
|                 config.xchannels, | ||||
|                 config.class_num, | ||||
|                 config.zero_init_residual, | ||||
|             ) | ||||
|         elif infer_mode == "depth": | ||||
|             return InferDepthCifarResNet( | ||||
|                 config.module, | ||||
|                 config.depth, | ||||
|                 config.xblocks, | ||||
|                 config.class_num, | ||||
|                 config.zero_init_residual, | ||||
|             ) | ||||
|         elif infer_mode == "shape": | ||||
|             return InferCifarResNet( | ||||
|                 config.module, | ||||
|                 config.depth, | ||||
|                 config.xblocks, | ||||
|                 config.xchannels, | ||||
|                 config.class_num, | ||||
|                 config.zero_init_residual, | ||||
|             ) | ||||
|         elif infer_mode == "nasnet.cifar": | ||||
|             genotype = config.genotype | ||||
|             if extra_path is not None:  # reload genotype by extra_path | ||||
|         if not osp.isfile(extra_path): raise ValueError('invalid extra_path : {:}'.format(extra_path)) | ||||
|                 if not osp.isfile(extra_path): | ||||
|                     raise ValueError("invalid extra_path : {:}".format(extra_path)) | ||||
|                 xdata = torch.load(extra_path) | ||||
|         current_epoch = xdata['epoch'] | ||||
|         genotype = xdata['genotypes'][current_epoch-1] | ||||
|       C = config.C if hasattr(config, 'C') else config.ichannel | ||||
|       N = config.N if hasattr(config, 'N') else config.layers | ||||
|       return NASNetonCIFAR(C, N, config.stem_multi, config.class_num, genotype, config.auxiliary) | ||||
|                 current_epoch = xdata["epoch"] | ||||
|                 genotype = xdata["genotypes"][current_epoch - 1] | ||||
|             C = config.C if hasattr(config, "C") else config.ichannel | ||||
|             N = config.N if hasattr(config, "N") else config.layers | ||||
|             return NASNetonCIFAR( | ||||
|                 C, N, config.stem_multi, config.class_num, genotype, config.auxiliary | ||||
|             ) | ||||
|         else: | ||||
|       raise ValueError('invalid infer-mode : {:}'.format(infer_mode)) | ||||
|             raise ValueError("invalid infer-mode : {:}".format(infer_mode)) | ||||
|     else: | ||||
|     raise ValueError('invalid super-type : {:}'.format(super_type)) | ||||
|         raise ValueError("invalid super-type : {:}".format(super_type)) | ||||
|  | ||||
|  | ||||
| def get_imagenet_models(config): | ||||
|   super_type = getattr(config, 'super_type', 'basic') | ||||
|   if super_type == 'basic': | ||||
|     super_type = getattr(config, "super_type", "basic") | ||||
|     if super_type == "basic": | ||||
|         from .ImageNet_ResNet import ResNet | ||||
|         from .ImageNet_MobileNetV2 import MobileNetV2 | ||||
|     if config.arch == 'resnet': | ||||
|       return ResNet(config.block_name, config.layers, config.deep_stem, config.class_num, config.zero_init_residual, config.groups, config.width_per_group) | ||||
|     elif config.arch == 'mobilenet_v2': | ||||
|       return MobileNetV2(config.class_num, config.width_multi, config.input_channel, config.last_channel, 'InvertedResidual', config.dropout) | ||||
|  | ||||
|         if config.arch == "resnet": | ||||
|             return ResNet( | ||||
|                 config.block_name, | ||||
|                 config.layers, | ||||
|                 config.deep_stem, | ||||
|                 config.class_num, | ||||
|                 config.zero_init_residual, | ||||
|                 config.groups, | ||||
|                 config.width_per_group, | ||||
|             ) | ||||
|         elif config.arch == "mobilenet_v2": | ||||
|             return MobileNetV2( | ||||
|                 config.class_num, | ||||
|                 config.width_multi, | ||||
|                 config.input_channel, | ||||
|                 config.last_channel, | ||||
|                 "InvertedResidual", | ||||
|                 config.dropout, | ||||
|             ) | ||||
|         else: | ||||
|       raise ValueError('invalid arch : {:}'.format( config.arch )) | ||||
|   elif super_type.startswith('infer'): # NAS searched architecture | ||||
|     assert len(super_type.split('-')) == 2, 'invalid super_type : {:}'.format(super_type) | ||||
|     infer_mode = super_type.split('-')[1] | ||||
|     if infer_mode == 'shape': | ||||
|             raise ValueError("invalid arch : {:}".format(config.arch)) | ||||
|     elif super_type.startswith("infer"):  # NAS searched architecture | ||||
|         assert len(super_type.split("-")) == 2, "invalid super_type : {:}".format( | ||||
|             super_type | ||||
|         ) | ||||
|         infer_mode = super_type.split("-")[1] | ||||
|         if infer_mode == "shape": | ||||
|             from .shape_infers import InferImagenetResNet | ||||
|             from .shape_infers import InferMobileNetV2 | ||||
|       if config.arch == 'resnet': | ||||
|         return InferImagenetResNet(config.block_name, config.layers, config.xblocks, config.xchannels, config.deep_stem, config.class_num, config.zero_init_residual) | ||||
|  | ||||
|             if config.arch == "resnet": | ||||
|                 return InferImagenetResNet( | ||||
|                     config.block_name, | ||||
|                     config.layers, | ||||
|                     config.xblocks, | ||||
|                     config.xchannels, | ||||
|                     config.deep_stem, | ||||
|                     config.class_num, | ||||
|                     config.zero_init_residual, | ||||
|                 ) | ||||
|             elif config.arch == "MobileNetV2": | ||||
|         return InferMobileNetV2(config.class_num, config.xchannels, config.xblocks, config.dropout) | ||||
|                 return InferMobileNetV2( | ||||
|                     config.class_num, config.xchannels, config.xblocks, config.dropout | ||||
|                 ) | ||||
|             else: | ||||
|         raise ValueError('invalid arch-mode : {:}'.format(config.arch)) | ||||
|                 raise ValueError("invalid arch-mode : {:}".format(config.arch)) | ||||
|         else: | ||||
|       raise ValueError('invalid infer-mode : {:}'.format(infer_mode)) | ||||
|             raise ValueError("invalid infer-mode : {:}".format(infer_mode)) | ||||
|     else: | ||||
|     raise ValueError('invalid super-type : {:}'.format(super_type)) | ||||
|         raise ValueError("invalid super-type : {:}".format(super_type)) | ||||
|  | ||||
|  | ||||
| # Try to obtain the network by config. | ||||
| def obtain_model(config, extra_path=None): | ||||
|   if config.dataset == 'cifar': | ||||
|     if config.dataset == "cifar": | ||||
|         return get_cifar_models(config, extra_path) | ||||
|   elif config.dataset == 'imagenet': | ||||
|     elif config.dataset == "imagenet": | ||||
|         return get_imagenet_models(config) | ||||
|     else: | ||||
|     raise ValueError('invalid dataset in the model config : {:}'.format(config)) | ||||
|         raise ValueError("invalid dataset in the model config : {:}".format(config)) | ||||
|  | ||||
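| A sketch of this top-level entry point; the field names follow the branches above, while the concrete values (including the `module` name) are assumed for illustration: | ||||
|  | ||||
| cfg = dict2config( | ||||
|     dict( | ||||
|         dataset="cifar", | ||||
|         super_type="basic", | ||||
|         arch="resnet", | ||||
|         module="ResNetBasicblock",  # assumed block-module name | ||||
|         depth=20, | ||||
|         class_num=10, | ||||
|         zero_init_residual=False, | ||||
|     ), | ||||
|     None, | ||||
| ) | ||||
| model = obtain_model(cfg)  # dispatches to get_cifar_models -> CifarResNet | ||||
|  | ||||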
|  | ||||
| def obtain_search_model(config): | ||||
|   if config.dataset == 'cifar': | ||||
|     if config.arch == 'resnet': | ||||
|     if config.dataset == "cifar": | ||||
|         if config.arch == "resnet": | ||||
|             from .shape_searchs import SearchWidthCifarResNet | ||||
|             from .shape_searchs import SearchDepthCifarResNet | ||||
|             from .shape_searchs import SearchShapeCifarResNet | ||||
|       if config.search_mode == 'width': | ||||
|         return SearchWidthCifarResNet(config.module, config.depth, config.class_num) | ||||
|       elif config.search_mode == 'depth': | ||||
|         return SearchDepthCifarResNet(config.module, config.depth, config.class_num) | ||||
|       elif config.search_mode == 'shape': | ||||
|         return SearchShapeCifarResNet(config.module, config.depth, config.class_num) | ||||
|       else: raise ValueError('invalid search mode : {:}'.format(config.search_mode)) | ||||
|     elif config.arch == 'simres': | ||||
|  | ||||
|             if config.search_mode == "width": | ||||
|                 return SearchWidthCifarResNet( | ||||
|                     config.module, config.depth, config.class_num | ||||
|                 ) | ||||
|             elif config.search_mode == "depth": | ||||
|                 return SearchDepthCifarResNet( | ||||
|                     config.module, config.depth, config.class_num | ||||
|                 ) | ||||
|             elif config.search_mode == "shape": | ||||
|                 return SearchShapeCifarResNet( | ||||
|                     config.module, config.depth, config.class_num | ||||
|                 ) | ||||
|             else: | ||||
|                 raise ValueError("invalid search mode : {:}".format(config.search_mode)) | ||||
|         elif config.arch == "simres": | ||||
|             from .shape_searchs import SearchWidthSimResNet | ||||
|       if config.search_mode == 'width': | ||||
|  | ||||
|             if config.search_mode == "width": | ||||
|                 return SearchWidthSimResNet(config.depth, config.class_num) | ||||
|       else: raise ValueError('invalid search mode : {:}'.format(config.search_mode)) | ||||
|             else: | ||||
|       raise ValueError('invalid arch : {:} for dataset [{:}]'.format(config.arch, config.dataset)) | ||||
|   elif config.dataset == 'imagenet': | ||||
|                 raise ValueError("invalid search mode : {:}".format(config.search_mode)) | ||||
|         else: | ||||
|             raise ValueError( | ||||
|                 "invalid arch : {:} for dataset [{:}]".format( | ||||
|                     config.arch, config.dataset | ||||
|                 ) | ||||
|             ) | ||||
|     elif config.dataset == "imagenet": | ||||
|         from .shape_searchs import SearchShapeImagenetResNet | ||||
|     assert config.search_mode == 'shape', 'invalid search-mode : {:}'.format( config.search_mode ) | ||||
|     if config.arch == 'resnet': | ||||
|       return SearchShapeImagenetResNet(config.block_name, config.layers, config.deep_stem, config.class_num) | ||||
|  | ||||
|         assert config.search_mode == "shape", "invalid search-mode : {:}".format( | ||||
|             config.search_mode | ||||
|         ) | ||||
|         if config.arch == "resnet": | ||||
|             return SearchShapeImagenetResNet( | ||||
|                 config.block_name, config.layers, config.deep_stem, config.class_num | ||||
|             ) | ||||
|         else: | ||||
|       raise ValueError('invalid model config : {:}'.format(config)) | ||||
|             raise ValueError("invalid model config : {:}".format(config)) | ||||
|     else: | ||||
|     raise ValueError('invalid dataset in the model config : {:}'.format(config)) | ||||
|         raise ValueError("invalid dataset in the model config : {:}".format(config)) | ||||
|  | ||||
|  | ||||
| def load_net_from_checkpoint(checkpoint): | ||||
|   assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint) | ||||
|     assert osp.isfile(checkpoint), "checkpoint {:} does not exist".format(checkpoint) | ||||
|     checkpoint = torch.load(checkpoint) | ||||
|   model_config = dict2config(checkpoint['model-config'], None) | ||||
|     model_config = dict2config(checkpoint["model-config"], None) | ||||
|     model = obtain_model(model_config) | ||||
|   model.load_state_dict(checkpoint['base-model']) | ||||
|     model.load_state_dict(checkpoint["base-model"]) | ||||
|     return model | ||||
|   | ||||
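|  | ||||
| The checkpoint is expected to carry both the serialized model config and the weights under fixed keys; a usage sketch (the path is a placeholder): | ||||
|  | ||||
| model = load_net_from_checkpoint("output/checkpoints/model-best.pth") | ||||
| model.eval() | ||||
| # the .pth file must contain {"model-config": <config dict>, "base-model": <state_dict>} | ||||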
| @@ -21,6 +21,10 @@ def get_model(config: Dict[Text, Any], **kwargs): | ||||
|         act_cls = super_name2activation[kwargs["act_cls"]] | ||||
|         norm_cls = super_name2norm[kwargs["norm_cls"]] | ||||
|         mean, std = kwargs.get("mean", None), kwargs.get("std", None) | ||||
|         if "hidden_dim" in kwargs: | ||||
|             hidden_dim1 = kwargs.get("hidden_dim") | ||||
|             hidden_dim2 = kwargs.get("hidden_dim") | ||||
|         else: | ||||
|             hidden_dim1 = kwargs.get("hidden_dim1", 200) | ||||
|             hidden_dim2 = kwargs.get("hidden_dim2", 100) | ||||
|         model = SuperSequential( | ||||
| @@ -34,4 +38,3 @@ def get_model(config: Dict[Text, Any], **kwargs): | ||||
|     else: | ||||
|         raise TypeError("Unkonwn model type: {:}".format(model_type)) | ||||
|     return model | ||||
|  | ||||
|   | ||||
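|  | ||||
| The new `hidden_dim` shortcut lets one kwarg set both hidden layers; otherwise the per-layer kwargs fall back to their defaults. A standalone sketch of the resolution rule (the helper name is hypothetical): | ||||
|  | ||||
| def resolve_hidden_dims(**kwargs): | ||||
|     # mirrors the branch added above | ||||
|     if "hidden_dim" in kwargs: | ||||
|         return kwargs["hidden_dim"], kwargs["hidden_dim"] | ||||
|     return kwargs.get("hidden_dim1", 200), kwargs.get("hidden_dim2", 100) | ||||
|  | ||||
| assert resolve_hidden_dims(hidden_dim=64) == (64, 64) | ||||
| assert resolve_hidden_dims(hidden_dim1=300) == (300, 100) | ||||
| assert resolve_hidden_dims() == (200, 100) | ||||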
| @@ -59,6 +59,9 @@ class TensorContainer: | ||||
|         for tensor in self._tensors: | ||||
|             tensor.requires_grad_(requires_grad) | ||||
|  | ||||
|     def parameters(self): | ||||
|         return self._tensors | ||||
|  | ||||
|     @property | ||||
|     def tensors(self): | ||||
|         return self._tensors | ||||
|   | ||||
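|  | ||||
| The new `parameters()` accessor makes a `TensorContainer` look enough like an `nn.Module` to be handed to a torch optimizer directly; a sketch, assuming `w_container` is such a container (e.g., from `model.get_w_container()`) whose tensors require grad: | ||||
|  | ||||
| import torch | ||||
|  | ||||
| optimizer = torch.optim.Adam(w_container.parameters(), lr=0.01) | ||||
| # each optimizer step now updates the container's tensors in place | ||||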