From 09d68c6375181ebe5b45abbd4f8d6c7e08134445 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Tue, 19 Nov 2019 11:58:04 +1100 Subject: [PATCH] update GDAS --- configs/nas-benchmark/LESS.config | 13 ++ exps/algos/BOHB.py | 31 ++- exps/algos/RANDOM.py | 35 ++-- exps/algos/R_EA.py | 41 ++-- exps/algos/reinforce.py | 37 ++-- lib/aa_nas_api/api.py | 10 +- lib/models/__init__.py | 9 +- lib/models/cell_searchs/search_cells.py | 40 +--- lib/models/cell_searchs/search_model_gdas.py | 12 +- lib/models/l2s_cell_searchs/__init__.py | 17 ++ lib/models/l2s_cell_searchs/_test_module.py | 12 ++ lib/models/l2s_cell_searchs/genotypes.py | 197 ++++++++++++++++++ lib/models/l2s_cell_searchs/search_cells.py | 148 +++++++++++++ .../l2s_cell_searchs/search_model_darts_v1.py | 93 +++++++++ .../l2s_cell_searchs/search_model_darts_v2.py | 93 +++++++++ .../l2s_cell_searchs/search_model_enas.py | 94 +++++++++ .../search_model_enas_utils.py | 55 +++++ .../l2s_cell_searchs/search_model_gdas.py | 96 +++++++++ .../l2s_cell_searchs/search_model_random.py | 81 +++++++ .../l2s_cell_searchs/search_model_setn.py | 152 ++++++++++++++ 20 files changed, 1176 insertions(+), 90 deletions(-) create mode 100644 configs/nas-benchmark/LESS.config create mode 100644 lib/models/l2s_cell_searchs/__init__.py create mode 100644 lib/models/l2s_cell_searchs/_test_module.py create mode 100644 lib/models/l2s_cell_searchs/genotypes.py create mode 100644 lib/models/l2s_cell_searchs/search_cells.py create mode 100644 lib/models/l2s_cell_searchs/search_model_darts_v1.py create mode 100644 lib/models/l2s_cell_searchs/search_model_darts_v2.py create mode 100644 lib/models/l2s_cell_searchs/search_model_enas.py create mode 100644 lib/models/l2s_cell_searchs/search_model_enas_utils.py create mode 100644 lib/models/l2s_cell_searchs/search_model_gdas.py create mode 100644 lib/models/l2s_cell_searchs/search_model_random.py create mode 100644 lib/models/l2s_cell_searchs/search_model_setn.py diff --git a/configs/nas-benchmark/LESS.config b/configs/nas-benchmark/LESS.config new file mode 100644 index 0000000..1e3e559 --- /dev/null +++ b/configs/nas-benchmark/LESS.config @@ -0,0 +1,13 @@ +{ + "scheduler": ["str", "cos"], + "eta_min" : ["float", "0.0"], + "epochs" : ["int", "10"], + "warmup" : ["int", "0"], + "optim" : ["str", "SGD"], + "LR" : ["float", "0.1"], + "decay" : ["float", "0.0005"], + "momentum" : ["float", "0.9"], + "nesterov" : ["bool", "1"], + "criterion": ["str", "Softmax"], + "batch_size": ["int", "256"] +} diff --git a/exps/algos/BOHB.py b/exps/algos/BOHB.py index 20a4853..9a400ab 100644 --- a/exps/algos/BOHB.py +++ b/exps/algos/BOHB.py @@ -69,7 +69,7 @@ class MyWorker(Worker): 'info': None}) -def main(xargs): +def main(xargs, nas_bench): assert torch.cuda.is_available(), 'CUDA is not available.' torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = False @@ -111,7 +111,7 @@ def main(xargs): ns_host, ns_port = NS.start() num_workers = 1 - nas_bench = AANASBenchAPI(xargs.arch_nas_dataset) + #nas_bench = AANASBenchAPI(xargs.arch_nas_dataset) logger.log('{:} Create AA-NAS-BENCH-API DONE'.format(time_string())) workers = [] for i in range(num_workers): @@ -140,15 +140,14 @@ def main(xargs): logger.log('Best found configuration: {:}'.format(id2config[incumbent]['config'])) best_arch = config2structure( id2config[incumbent]['config'] ) - if nas_bench is not None: - info = nas_bench.query_by_arch( best_arch ) - if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) - else : logger.log('{:}'.format(info)) + info = nas_bench.query_by_arch( best_arch ) + if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) + else : logger.log('{:}'.format(info)) logger.log('-'*100) logger.log('workers : {:}'.format(workers[0].test_time)) - logger.close() + return logger.log_dir, nas_bench.query_index_by_arch( best_arch ) @@ -175,5 +174,19 @@ if __name__ == '__main__': parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') parser.add_argument('--rand_seed', type=int, help='manual seed') args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) - main(args) + #if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) + if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): + nas_bench = None + else: + print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset)) + nas_bench = AANASBenchAPI(args.arch_nas_dataset) + if args.rand_seed < 0: + save_dir, all_indexes, num = None, [], 500 + for i in range(num): + print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num)) + args.rand_seed = random.randint(1, 100000) + save_dir, index = main(args, nas_bench) + all_indexes.append( index ) + torch.save(all_indexes, save_dir / 'results.pth') + else: + main(args, nas_bench) diff --git a/exps/algos/RANDOM.py b/exps/algos/RANDOM.py index 29d93d0..d367e71 100644 --- a/exps/algos/RANDOM.py +++ b/exps/algos/RANDOM.py @@ -19,7 +19,7 @@ from aa_nas_api import AANASBenchAPI from R_EA import train_and_eval, random_architecture_func -def main(xargs): +def main(xargs, nas_bench): assert torch.cuda.is_available(), 'CUDA is not available.' torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = False @@ -51,12 +51,6 @@ def main(xargs): search_space = get_search_spaces('cell', xargs.search_space_name) random_arch = random_architecture_func(xargs.max_nodes, search_space) #x =random_arch() ; y = mutate_arch(x) - if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset): - logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset)) - nas_bench = None - else: - logger.log('{:} build NAS-Benchmark-API from {:}'.format(time_string(), xargs.arch_nas_dataset)) - nas_bench = AANASBenchAPI(xargs.arch_nas_dataset) logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) best_arch, best_acc = None, -1 for idx in range(xargs.random_num): @@ -67,13 +61,12 @@ def main(xargs): logger.log('[{:03d}/{:03d}] : {:} : accuracy = {:.2f}%'.format(idx, xargs.random_num, arch, accuracy)) logger.log('{:} best arch is {:}, accuracy = {:.2f}%'.format(time_string(), best_arch, best_acc)) - if nas_bench is not None: - info = nas_bench.query_by_arch( best_arch ) - if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) - else : logger.log('{:}'.format(info)) + info = nas_bench.query_by_arch( best_arch ) + if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) + else : logger.log('{:}'.format(info)) logger.log('-'*100) - logger.close() + return logger.log_dir, nas_bench.query_index_by_arch( best_arch ) @@ -94,5 +87,19 @@ if __name__ == '__main__': parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') parser.add_argument('--rand_seed', type=int, help='manual seed') args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) - main(args) + #if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) + if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): + nas_bench = None + else: + print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset)) + nas_bench = AANASBenchAPI(args.arch_nas_dataset) + if args.rand_seed < 0: + save_dir, all_indexes, num = None, [], 500 + for i in range(num): + print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num)) + args.rand_seed = random.randint(1, 100000) + save_dir, index = main(args, nas_bench) + all_indexes.append( index ) + torch.save(all_indexes, save_dir / 'results.pth') + else: + main(args, nas_bench) diff --git a/exps/algos/R_EA.py b/exps/algos/R_EA.py index cd7d391..e66957c 100644 --- a/exps/algos/R_EA.py +++ b/exps/algos/R_EA.py @@ -60,7 +60,8 @@ def train_and_eval(arch, nas_bench, extra_info): arch_index = nas_bench.query_index_by_arch( arch ) assert arch_index >= 0, 'can not find this arch : {:}'.format(arch) info = nas_bench.arch2infos[ arch_index ] - _, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25) # use the validation accuracy after 25 training epochs + _, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs + #import pdb; pdb.set_trace() else: # train a model from scratch. raise ValueError('NOT IMPLEMENT YET') @@ -153,7 +154,7 @@ def regularized_evolution(cycles, population_size, sample_size, random_arch, mut return history -def main(xargs): +def main(xargs, nas_bench): assert torch.cuda.is_available(), 'CUDA is not available.' torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = False @@ -186,12 +187,6 @@ def main(xargs): random_arch = random_architecture_func(xargs.max_nodes, search_space) mutate_arch = mutate_arch_func(search_space) #x =random_arch() ; y = mutate_arch(x) - if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset): - logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset)) - nas_bench = None - else: - logger.log('{:} build NAS-Benchmark-API from {:}'.format(time_string(), xargs.arch_nas_dataset)) - nas_bench = AANASBenchAPI(xargs.arch_nas_dataset) logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) history = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, random_arch, mutate_arch, nas_bench if args.ea_fast_by_api else None, extra_info) logger.log('{:} regularized_evolution finish with history of {:} arch.'.format(time_string(), len(history))) @@ -199,13 +194,12 @@ def main(xargs): best_arch = best_arch.arch logger.log('{:} best arch is {:}'.format(time_string(), best_arch)) - if nas_bench is not None: - info = nas_bench.query_by_arch( best_arch ) - if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) - else : logger.log('{:}'.format(info)) + info = nas_bench.query_by_arch( best_arch ) + if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) + else : logger.log('{:}'.format(info)) logger.log('-'*100) - logger.close() + return logger.log_dir, nas_bench.query_index_by_arch( best_arch ) @@ -227,8 +221,23 @@ if __name__ == '__main__': parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).') parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') - parser.add_argument('--rand_seed', type=int, help='manual seed') + parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed') args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) + #if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) args.ea_fast_by_api = args.ea_fast_by_api > 0 - main(args) + + if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): + nas_bench = None + else: + print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset)) + nas_bench = AANASBenchAPI(args.arch_nas_dataset) + if args.rand_seed < 0: + save_dir, all_indexes, num = None, [], 500 + for i in range(num): + print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num)) + args.rand_seed = random.randint(1, 100000) + save_dir, index = main(args, nas_bench) + all_indexes.append( index ) + torch.save(all_indexes, save_dir / 'results.pth') + else: + main(args, nas_bench) diff --git a/exps/algos/reinforce.py b/exps/algos/reinforce.py index 73695ad..80a4e7b 100644 --- a/exps/algos/reinforce.py +++ b/exps/algos/reinforce.py @@ -89,7 +89,7 @@ def select_action(policy): return m.log_prob(action), action.cpu().tolist() -def main(xargs): +def main(xargs, nas_bench): assert torch.cuda.is_available(), 'CUDA is not available.' torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = False @@ -128,12 +128,6 @@ def main(xargs): logger.log('eps : {:}'.format(eps)) # nas dataset load - if xargs.arch_nas_dataset is None or not os.path.isfile(xargs.arch_nas_dataset): - logger.log('Can not find the architecture dataset : {:}.'.format(xargs.arch_nas_dataset)) - nas_bench = None - else: - logger.log('{:} build NAS-Benchmark-API from {:}'.format(time_string(), xargs.arch_nas_dataset)) - nas_bench = AANASBenchAPI(xargs.arch_nas_dataset) logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) # REINFORCE @@ -156,13 +150,12 @@ def main(xargs): best_arch = policy.genotype() - if nas_bench is not None: - info = nas_bench.query_by_arch( best_arch ) - if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) - else : logger.log('{:}'.format(info)) + info = nas_bench.query_by_arch( best_arch ) + if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) + else : logger.log('{:}'.format(info)) logger.log('-'*100) - logger.close() + return logger.log_dir, nas_bench.query_index_by_arch( best_arch ) @@ -183,7 +176,21 @@ if __name__ == '__main__': parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).') parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') - parser.add_argument('--rand_seed', type=int, help='manual seed') + parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed') args = parser.parse_args() - if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) - main(args) + #if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) + if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): + nas_bench = None + else: + print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset)) + nas_bench = AANASBenchAPI(args.arch_nas_dataset) + if args.rand_seed < 0: + save_dir, all_indexes, num = None, [], 500 + for i in range(num): + print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num)) + args.rand_seed = random.randint(1, 100000) + save_dir, index = main(args, nas_bench) + all_indexes.append( index ) + torch.save(all_indexes, save_dir / 'results.pth') + else: + main(args, nas_bench) diff --git a/lib/aa_nas_api/api.py b/lib/aa_nas_api/api.py index 3e4e351..db2a310 100644 --- a/lib/aa_nas_api/api.py +++ b/lib/aa_nas_api/api.py @@ -1,7 +1,7 @@ ################################################## # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # ################################################## -import os, sys, copy, torch, numpy as np +import os, sys, copy, random, torch, numpy as np from collections import OrderedDict @@ -149,7 +149,7 @@ class ArchResults(object): lantencies = [result.get_latency() for result in results] return np.mean(flops), np.mean(params), np.mean(lantencies) - def get_metrics(self, dataset, setname, iepoch=None): + def get_metrics(self, dataset, setname, iepoch=None, is_random=False): x_seeds = self.dataset_seed[dataset] results = [self.all_results[ (dataset, seed) ] for seed in x_seeds] loss, accuracy = [], [] @@ -160,7 +160,11 @@ class ArchResults(object): info = result.get_eval(setname, iepoch) loss.append( info['loss'] ) accuracy.append( info['accuracy'] ) - return float(np.mean(loss)), float(np.mean(accuracy)) + if is_random: + index = random.randint(0, len(loss)-1) + return loss[index], accuracy[index] + else: + return float(np.mean(loss)), float(np.mean(accuracy)) def show(self, is_print=False): return print_information(self, None, is_print) diff --git a/lib/models/__init__.py b/lib/models/__init__.py index d42305c..5acd8ec 100644 --- a/lib/models/__init__.py +++ b/lib/models/__init__.py @@ -16,10 +16,15 @@ from .cell_searchs import CellStructure, CellArchitectures # Cell-based NAS Models def get_cell_based_tiny_net(config): + super_type = getattr(config, 'super_type', 'basic') group_names = ['DARTS-V1', 'DARTS-V2', 'GDAS', 'SETN', 'ENAS', 'RANDOM'] - from .cell_searchs import nas_super_nets - if config.name in group_names: + if super_type == 'basic' and config.name in group_names: + from .cell_searchs import nas_super_nets return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space) + elif super_type == 'l2s-base' and config.name in group_names: + from .l2s_cell_searchs import nas_super_nets + return nas_super_nets[config.name](config.C, config.N, config.max_nodes, config.num_classes, config.space \ + ,config.n_piece) elif config.name == 'infer.tiny': from .cell_infers import TinyNetwork return TinyNetwork(config.C, config.N, config.genotype, config.num_classes) diff --git a/lib/models/cell_searchs/search_cells.py b/lib/models/cell_searchs/search_cells.py index 49124be..d510ba1 100644 --- a/lib/models/cell_searchs/search_cells.py +++ b/lib/models/cell_searchs/search_cells.py @@ -47,35 +47,17 @@ class SearchCell(nn.Module): return nodes[-1] # GDAS - def forward_gdas(self, inputs, alphas, _tau): - avoid_zero = 0 - while True: - gumbels = -torch.empty_like(alphas).exponential_().log() - logits = (alphas.log_softmax(dim=1) + gumbels) / _tau - probs = nn.functional.softmax(logits, dim=1) - index = probs.max(-1, keepdim=True)[1] - one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) - hardwts = one_h - probs.detach() + probs - if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): - continue # avoid the numerical error - nodes = [inputs] - for i in range(1, self.max_nodes): - inter_nodes = [] - for j in range(i): - node_str = '{:}<-{:}'.format(i, j) - weights = hardwts[ self.edge2index[node_str] ] - argmaxs = index[ self.edge2index[node_str] ].item() - weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) ) - inter_nodes.append( weigsum ) - nodes.append( sum(inter_nodes) ) - avoid_zero += 1 - if nodes[-1].sum().item() == 0: - if avoid_zero < 10: continue - else: - warnings.warn('get zero outputs with avoid_zero={:}'.format(avoid_zero)) - break - else: - break + def forward_gdas(self, inputs, hardwts, index): + nodes = [inputs] + for i in range(1, self.max_nodes): + inter_nodes = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + weights = hardwts[ self.edge2index[node_str] ] + argmaxs = index[ self.edge2index[node_str] ].item() + weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) ) + inter_nodes.append( weigsum ) + nodes.append( sum(inter_nodes) ) return nodes[-1] # joint diff --git a/lib/models/cell_searchs/search_model_gdas.py b/lib/models/cell_searchs/search_model_gdas.py index 6a4dd4e..84ddcce 100644 --- a/lib/models/cell_searchs/search_model_gdas.py +++ b/lib/models/cell_searchs/search_model_gdas.py @@ -81,13 +81,21 @@ class TinyNetworkGDAS(nn.Module): return Structure( genotypes ) def forward(self, inputs): + while True: + gumbels = -torch.empty_like(self.arch_parameters).exponential_().log() + logits = (self.arch_parameters.log_softmax(dim=1) + gumbels) / self.tau + probs = nn.functional.softmax(logits, dim=1) + index = probs.max(-1, keepdim=True)[1] + one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) + hardwts = one_h - probs.detach() + probs + if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): continue + feature = self.stem(inputs) for i, cell in enumerate(self.cells): if isinstance(cell, SearchCell): - feature = cell.forward_gdas(feature, self.arch_parameters, self.tau) + feature = cell.forward_gdas(feature, hardwts, index) else: feature = cell(feature) - out = self.lastact(feature) out = self.global_pooling( out ) out = out.view(out.size(0), -1) diff --git a/lib/models/l2s_cell_searchs/__init__.py b/lib/models/l2s_cell_searchs/__init__.py new file mode 100644 index 0000000..2133795 --- /dev/null +++ b/lib/models/l2s_cell_searchs/__init__.py @@ -0,0 +1,17 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +################################################## +from .search_model_darts_v1 import TinyNetworkDartsV1 +from .search_model_darts_v2 import TinyNetworkDartsV2 +from .search_model_gdas import TinyNetworkGDAS +from .search_model_setn import TinyNetworkSETN +from .search_model_enas import TinyNetworkENAS +from .search_model_random import TinyNetworkRANDOM +from .genotypes import Structure as CellStructure, architectures as CellArchitectures + +nas_super_nets = {'DARTS-V1': TinyNetworkDartsV1, + 'DARTS-V2': TinyNetworkDartsV2, + 'GDAS' : TinyNetworkGDAS, + 'SETN' : TinyNetworkSETN, + 'ENAS' : TinyNetworkENAS, + 'RANDOM' : TinyNetworkRANDOM} diff --git a/lib/models/l2s_cell_searchs/_test_module.py b/lib/models/l2s_cell_searchs/_test_module.py new file mode 100644 index 0000000..c603ba6 --- /dev/null +++ b/lib/models/l2s_cell_searchs/_test_module.py @@ -0,0 +1,12 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +################################################## +import torch +from search_model_enas_utils import Controller + +def main(): + controller = Controller(6, 4) + predictions = controller() + +if __name__ == '__main__': + main() diff --git a/lib/models/l2s_cell_searchs/genotypes.py b/lib/models/l2s_cell_searchs/genotypes.py new file mode 100644 index 0000000..e0f2e2e --- /dev/null +++ b/lib/models/l2s_cell_searchs/genotypes.py @@ -0,0 +1,197 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +################################################## +from copy import deepcopy + + + +def get_combination(space, num): + combs = [] + for i in range(num): + if i == 0: + for func in space: + combs.append( [(func, i)] ) + else: + new_combs = [] + for string in combs: + for func in space: + xstring = string + [(func, i)] + new_combs.append( xstring ) + combs = new_combs + return combs + + + +class Structure: + + def __init__(self, genotype): + assert isinstance(genotype, list) or isinstance(genotype, tuple), 'invalid class of genotype : {:}'.format(type(genotype)) + self.node_num = len(genotype) + 1 + self.nodes = [] + self.node_N = [] + for idx, node_info in enumerate(genotype): + assert isinstance(node_info, list) or isinstance(node_info, tuple), 'invalid class of node_info : {:}'.format(type(node_info)) + assert len(node_info) >= 1, 'invalid length : {:}'.format(len(node_info)) + for node_in in node_info: + assert isinstance(node_in, list) or isinstance(node_in, tuple), 'invalid class of in-node : {:}'.format(type(node_in)) + assert len(node_in) == 2 and node_in[1] <= idx, 'invalid in-node : {:}'.format(node_in) + self.node_N.append( len(node_info) ) + self.nodes.append( tuple(deepcopy(node_info)) ) + + def tolist(self, remove_str): + # convert this class to the list, if remove_str is 'none', then remove the 'none' operation. + # note that we re-order the input node in this function + # return the-genotype-list and success [if unsuccess, it is not a connectivity] + genotypes = [] + for node_info in self.nodes: + node_info = list( node_info ) + node_info = sorted(node_info, key=lambda x: (x[1], x[0])) + node_info = tuple(filter(lambda x: x[0] != remove_str, node_info)) + if len(node_info) == 0: return None, False + genotypes.append( node_info ) + return genotypes, True + + def node(self, index): + assert index > 0 and index <= len(self), 'invalid index={:} < {:}'.format(index, len(self)) + return self.nodes[index] + + def tostr(self): + strings = [] + for node_info in self.nodes: + string = '|'.join([x[0]+'~{:}'.format(x[1]) for x in node_info]) + string = '|{:}|'.format(string) + strings.append( string ) + return '+'.join(strings) + + def check_valid(self): + nodes = {0: True} + for i, node_info in enumerate(self.nodes): + sums = [] + for op, xin in node_info: + if op == 'none' or nodes[xin] == False: x = False + else: x = True + sums.append( x ) + nodes[i+1] = sum(sums) > 0 + return nodes[len(self.nodes)] + + def to_unique_str(self, consider_zero=False): + # this is used to identify the isomorphic cell, which rerquires the prior knowledge of operation + # two operations are special, i.e., none and skip_connect + nodes = {0: '0'} + for i_node, node_info in enumerate(self.nodes): + cur_node = [] + for op, xin in node_info: + if consider_zero: + if op == 'none' or nodes[xin] == '#': x = '#' # zero + elif op == 'skip_connect': x = nodes[xin] + else: x = '('+nodes[xin]+')' + '@{:}'.format(op) + else: + if op == 'skip_connect': x = nodes[xin] + else: x = '('+nodes[xin]+')' + '@{:}'.format(op) + cur_node.append(x) + nodes[i_node+1] = '+'.join( sorted(cur_node) ) + return nodes[ len(self.nodes) ] + + def check_valid_op(self, op_names): + for node_info in self.nodes: + for inode_edge in node_info: + #assert inode_edge[0] in op_names, 'invalid op-name : {:}'.format(inode_edge[0]) + if inode_edge[0] not in op_names: return False + return True + + def __repr__(self): + return ('{name}({node_num} nodes with {node_info})'.format(name=self.__class__.__name__, node_info=self.tostr(), **self.__dict__)) + + def __len__(self): + return len(self.nodes) + 1 + + def __getitem__(self, index): + return self.nodes[index] + + @staticmethod + def str2structure(xstr): + assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) + nodestrs = xstr.split('+') + genotypes = [] + for i, node_str in enumerate(nodestrs): + inputs = list(filter(lambda x: x != '', node_str.split('|'))) + for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) + inputs = ( xi.split('~') for xi in inputs ) + input_infos = tuple( (op, int(IDX)) for (op, IDX) in inputs) + genotypes.append( input_infos ) + return Structure( genotypes ) + + @staticmethod + def str2fullstructure(xstr, default_name='none'): + assert isinstance(xstr, str), 'must take string (not {:}) as input'.format(type(xstr)) + nodestrs = xstr.split('+') + genotypes = [] + for i, node_str in enumerate(nodestrs): + inputs = list(filter(lambda x: x != '', node_str.split('|'))) + for xinput in inputs: assert len(xinput.split('~')) == 2, 'invalid input length : {:}'.format(xinput) + inputs = ( xi.split('~') for xi in inputs ) + input_infos = list( (op, int(IDX)) for (op, IDX) in inputs) + all_in_nodes= list(x[1] for x in input_infos) + for j in range(i): + if j not in all_in_nodes: input_infos.append((default_name, j)) + node_info = sorted(input_infos, key=lambda x: (x[1], x[0])) + genotypes.append( tuple(node_info) ) + return Structure( genotypes ) + + @staticmethod + def gen_all(search_space, num, return_ori): + assert isinstance(search_space, list) or isinstance(search_space, tuple), 'invalid class of search-space : {:}'.format(type(search_space)) + assert num >= 2, 'There should be at least two nodes in a neural cell instead of {:}'.format(num) + all_archs = get_combination(search_space, 1) + for i, arch in enumerate(all_archs): + all_archs[i] = [ tuple(arch) ] + + for inode in range(2, num): + cur_nodes = get_combination(search_space, inode) + new_all_archs = [] + for previous_arch in all_archs: + for cur_node in cur_nodes: + new_all_archs.append( previous_arch + [tuple(cur_node)] ) + all_archs = new_all_archs + if return_ori: + return all_archs + else: + return [Structure(x) for x in all_archs] + + + +ResNet_CODE = Structure( + [(('nor_conv_3x3', 0), ), # node-1 + (('nor_conv_3x3', 1), ), # node-2 + (('skip_connect', 0), ('skip_connect', 2))] # node-3 + ) + +AllConv3x3_CODE = Structure( + [(('nor_conv_3x3', 0), ), # node-1 + (('nor_conv_3x3', 0), ('nor_conv_3x3', 1)), # node-2 + (('nor_conv_3x3', 0), ('nor_conv_3x3', 1), ('nor_conv_3x3', 2))] # node-3 + ) + +AllFull_CODE = Structure( + [(('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0)), # node-1 + (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1)), # node-2 + (('skip_connect', 0), ('nor_conv_1x1', 0), ('nor_conv_3x3', 0), ('avg_pool_3x3', 0), ('skip_connect', 1), ('nor_conv_1x1', 1), ('nor_conv_3x3', 1), ('avg_pool_3x3', 1), ('skip_connect', 2), ('nor_conv_1x1', 2), ('nor_conv_3x3', 2), ('avg_pool_3x3', 2))] # node-3 + ) + +AllConv1x1_CODE = Structure( + [(('nor_conv_1x1', 0), ), # node-1 + (('nor_conv_1x1', 0), ('nor_conv_1x1', 1)), # node-2 + (('nor_conv_1x1', 0), ('nor_conv_1x1', 1), ('nor_conv_1x1', 2))] # node-3 + ) + +AllIdentity_CODE = Structure( + [(('skip_connect', 0), ), # node-1 + (('skip_connect', 0), ('skip_connect', 1)), # node-2 + (('skip_connect', 0), ('skip_connect', 1), ('skip_connect', 2))] # node-3 + ) + +architectures = {'resnet' : ResNet_CODE, + 'all_c3x3': AllConv3x3_CODE, + 'all_c1x1': AllConv1x1_CODE, + 'all_idnt': AllIdentity_CODE, + 'all_full': AllFull_CODE} diff --git a/lib/models/l2s_cell_searchs/search_cells.py b/lib/models/l2s_cell_searchs/search_cells.py new file mode 100644 index 0000000..fba750f --- /dev/null +++ b/lib/models/l2s_cell_searchs/search_cells.py @@ -0,0 +1,148 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +################################################## +import math, random, torch +import warnings +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy +from ..cell_operations import OPS + + +class SearchCell(nn.Module): + + def __init__(self, C_in, C_out, stride, max_nodes, op_names, n_piece): + super(SearchCell, self).__init__() + + self.op_names = deepcopy(op_names) + self.max_nodes = max_nodes + self.in_dim = C_in + self.out_dim = C_out + self.n_piece = n_piece + self.multi_edges = nn.ModuleList() + for i_piece in range(n_piece): + edges = nn.ModuleDict() + for i in range(1, max_nodes): + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + if j == 0: xlists = [OPS[op_name](C_in , C_out, stride) for op_name in op_names] + else : xlists = [OPS[op_name](C_in , C_out, 1) for op_name in op_names] + edges[ node_str ] = nn.ModuleList( xlists ) + self.multi_edges.append( edges ) + + self.edge_keys = sorted(list(edges.keys())) + self.edge2index = {key:i for i, key in enumerate(self.edge_keys)} + self.num_edges = len(edges) + + def extra_repr(self): + string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}, nP={n_piece}'.format(**self.__dict__) + return string + + def forward(self, inputs, weightss): + nodes = [inputs] + with torch.no_grad(): + xmod, xid, argmax = 1, 0, weightss.argmax(dim=1).cpu().tolist() + for i, x in enumerate(argmax): + xid += x * (xmod % self.n_piece) + xmod = (xmod * len(self.op_names)) % self.n_piece + xid = xid % self.n_piece + edges = self.multi_edges[xid] + for i in range(1, self.max_nodes): + inter_nodes = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + weights = weightss[ self.edge2index[node_str] ] + inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(edges[node_str], weights) ) ) + nodes.append( sum(inter_nodes) ) + return nodes[-1] + + # GDAS + def forward_gdas(self, inputs, alphas, _tau): + avoid_zero = 0 + while True: + gumbels = -torch.empty_like(alphas).exponential_().log() + logits = (alphas.log_softmax(dim=1) + gumbels) / _tau + probs = nn.functional.softmax(logits, dim=1) + index = probs.max(-1, keepdim=True)[1] + one_h = torch.zeros_like(logits).scatter_(-1, index, 1.0) + hardwts = one_h - probs.detach() + probs + if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): + continue # avoid the numerical error + nodes = [inputs] + for i in range(1, self.max_nodes): + inter_nodes = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + weights = hardwts[ self.edge2index[node_str] ] + argmaxs = index[ self.edge2index[node_str] ].item() + weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) ) + inter_nodes.append( weigsum ) + nodes.append( sum(inter_nodes) ) + avoid_zero += 1 + if nodes[-1].sum().item() == 0: + if avoid_zero < 10: continue + else: + warnings.warn('get zero outputs with avoid_zero={:}'.format(avoid_zero)) + break + else: + break + return nodes[-1] + + # joint + def forward_joint(self, inputs, weightss): + nodes = [inputs] + for i in range(1, self.max_nodes): + inter_nodes = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + weights = weightss[ self.edge2index[node_str] ] + #aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel() + aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) + inter_nodes.append( aggregation ) + nodes.append( sum(inter_nodes) ) + return nodes[-1] + + # uniform random sampling per iteration + def forward_urs(self, inputs): + nodes = [inputs] + for i in range(1, self.max_nodes): + while True: # to avoid select zero for all ops + sops, has_non_zero = [], False + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + candidates = self.edges[node_str] + select_op = random.choice(candidates) + sops.append( select_op ) + if not hasattr(select_op, 'is_zero') or select_op.is_zero == False: has_non_zero=True + if has_non_zero: break + inter_nodes = [] + for j, select_op in enumerate(sops): + inter_nodes.append( select_op(nodes[j]) ) + nodes.append( sum(inter_nodes) ) + return nodes[-1] + + # select the argmax + def forward_select(self, inputs, weightss): + nodes = [inputs] + for i in range(1, self.max_nodes): + inter_nodes = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + weights = weightss[ self.edge2index[node_str] ] + inter_nodes.append( self.edges[node_str][ weights.argmax().item() ]( nodes[j] ) ) + #inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) ) + nodes.append( sum(inter_nodes) ) + return nodes[-1] + + # forward with a specific structure + def forward_dynamic(self, inputs, structure): + nodes = [inputs] + for i in range(1, self.max_nodes): + cur_op_node = structure.nodes[i-1] + inter_nodes = [] + for op_name, j in cur_op_node: + node_str = '{:}<-{:}'.format(i, j) + op_index = self.op_names.index( op_name ) + inter_nodes.append( self.edges[node_str][op_index]( nodes[j] ) ) + nodes.append( sum(inter_nodes) ) + return nodes[-1] diff --git a/lib/models/l2s_cell_searchs/search_model_darts_v1.py b/lib/models/l2s_cell_searchs/search_model_darts_v1.py new file mode 100644 index 0000000..ffc381e --- /dev/null +++ b/lib/models/l2s_cell_searchs/search_model_darts_v1.py @@ -0,0 +1,93 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +######################################################## +# DARTS: Differentiable Architecture Search, ICLR 2019 # +######################################################## +import torch +import torch.nn as nn +from copy import deepcopy +from ..cell_operations import ResNetBasicblock +from .search_cells import SearchCell +from .genotypes import Structure + + +class TinyNetworkDartsV1(nn.Module): + + def __init__(self, C, N, max_nodes, num_classes, search_space, n_piece): + super(TinyNetworkDartsV1, self).__init__() + self._C = C + self._layerN = N + self.max_nodes = max_nodes + self.stem = nn.Sequential( + nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(C)) + + layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + + C_prev, num_edge, edge2index = C, None, None + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): + if reduction: + cell = ResNetBasicblock(C_prev, C_curr, 2) + else: + cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, n_piece) + if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index + else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) + self.cells.append( cell ) + C_prev = cell.out_dim + self.op_names = deepcopy( search_space ) + self._Layer = len(self.cells) + self.edge2index = edge2index + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) + + def get_weights(self): + xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) + xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) + xlist+= list( self.classifier.parameters() ) + return xlist + + def get_alphas(self): + return [self.arch_parameters] + + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) + return string + + def extra_repr(self): + return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) + + def genotype(self): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + with torch.no_grad(): + weights = self.arch_parameters[ self.edge2index[node_str] ] + op_name = self.op_names[ weights.argmax().item() ] + xlist.append((op_name, j)) + genotypes.append( tuple(xlist) ) + return Structure( genotypes ) + + def forward(self, inputs): + alphas = nn.functional.softmax(self.arch_parameters, dim=-1) + + feature = self.stem(inputs) + for i, cell in enumerate(self.cells): + if isinstance(cell, SearchCell): + feature = cell(feature, alphas) + else: + feature = cell(feature) + + out = self.lastact(feature) + out = self.global_pooling( out ) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + + return out, logits diff --git a/lib/models/l2s_cell_searchs/search_model_darts_v2.py b/lib/models/l2s_cell_searchs/search_model_darts_v2.py new file mode 100644 index 0000000..cb996ff --- /dev/null +++ b/lib/models/l2s_cell_searchs/search_model_darts_v2.py @@ -0,0 +1,93 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +######################################################## +# DARTS: Differentiable Architecture Search, ICLR 2019 # +######################################################## +import torch +import torch.nn as nn +from copy import deepcopy +from ..cell_operations import ResNetBasicblock +from .search_cells import SearchCell +from .genotypes import Structure + + +class TinyNetworkDartsV2(nn.Module): + + def __init__(self, C, N, max_nodes, num_classes, search_space): + super(TinyNetworkDartsV2, self).__init__() + self._C = C + self._layerN = N + self.max_nodes = max_nodes + self.stem = nn.Sequential( + nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(C)) + + layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + + C_prev, num_edge, edge2index = C, None, None + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): + if reduction: + cell = ResNetBasicblock(C_prev, C_curr, 2) + else: + cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space) + if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index + else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) + self.cells.append( cell ) + C_prev = cell.out_dim + self.op_names = deepcopy( search_space ) + self._Layer = len(self.cells) + self.edge2index = edge2index + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) + + def get_weights(self): + xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) + xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) + xlist+= list( self.classifier.parameters() ) + return xlist + + def get_alphas(self): + return [self.arch_parameters] + + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) + return string + + def extra_repr(self): + return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) + + def genotype(self): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + with torch.no_grad(): + weights = self.arch_parameters[ self.edge2index[node_str] ] + op_name = self.op_names[ weights.argmax().item() ] + xlist.append((op_name, j)) + genotypes.append( tuple(xlist) ) + return Structure( genotypes ) + + def forward(self, inputs): + alphas = nn.functional.softmax(self.arch_parameters, dim=-1) + + feature = self.stem(inputs) + for i, cell in enumerate(self.cells): + if isinstance(cell, SearchCell): + feature = cell(feature, alphas) + else: + feature = cell(feature) + + out = self.lastact(feature) + out = self.global_pooling( out ) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + + return out, logits diff --git a/lib/models/l2s_cell_searchs/search_model_enas.py b/lib/models/l2s_cell_searchs/search_model_enas.py new file mode 100644 index 0000000..2422b52 --- /dev/null +++ b/lib/models/l2s_cell_searchs/search_model_enas.py @@ -0,0 +1,94 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +########################################################################## +# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 # +########################################################################## +import torch +import torch.nn as nn +from copy import deepcopy +from ..cell_operations import ResNetBasicblock +from .search_cells import SearchCell +from .genotypes import Structure +from .search_model_enas_utils import Controller + + +class TinyNetworkENAS(nn.Module): + + def __init__(self, C, N, max_nodes, num_classes, search_space): + super(TinyNetworkENAS, self).__init__() + self._C = C + self._layerN = N + self.max_nodes = max_nodes + self.stem = nn.Sequential( + nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(C)) + + layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + + C_prev, num_edge, edge2index = C, None, None + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): + if reduction: + cell = ResNetBasicblock(C_prev, C_curr, 2) + else: + cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space) + if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index + else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) + self.cells.append( cell ) + C_prev = cell.out_dim + self.op_names = deepcopy( search_space ) + self._Layer = len(self.cells) + self.edge2index = edge2index + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + # to maintain the sampled architecture + self.sampled_arch = None + + def update_arch(self, _arch): + if _arch is None: + self.sampled_arch = None + elif isinstance(_arch, Structure): + self.sampled_arch = _arch + elif isinstance(_arch, (list, tuple)): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + op_index = _arch[ self.edge2index[node_str] ] + op_name = self.op_names[ op_index ] + xlist.append((op_name, j)) + genotypes.append( tuple(xlist) ) + self.sampled_arch = Structure(genotypes) + else: + raise ValueError('invalid type of input architecture : {:}'.format(_arch)) + return self.sampled_arch + + def create_controller(self): + return Controller(len(self.edge2index), len(self.op_names)) + + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) + return string + + def extra_repr(self): + return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) + + def forward(self, inputs): + + feature = self.stem(inputs) + for i, cell in enumerate(self.cells): + if isinstance(cell, SearchCell): + feature = cell.forward_dynamic(feature, self.sampled_arch) + else: feature = cell(feature) + + out = self.lastact(feature) + out = self.global_pooling( out ) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + + return out, logits diff --git a/lib/models/l2s_cell_searchs/search_model_enas_utils.py b/lib/models/l2s_cell_searchs/search_model_enas_utils.py new file mode 100644 index 0000000..e03f57b --- /dev/null +++ b/lib/models/l2s_cell_searchs/search_model_enas_utils.py @@ -0,0 +1,55 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +########################################################################## +# Efficient Neural Architecture Search via Parameters Sharing, ICML 2018 # +########################################################################## +import torch +import torch.nn as nn +from torch.distributions.categorical import Categorical + +class Controller(nn.Module): + # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py + def __init__(self, num_edge, num_ops, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0): + super(Controller, self).__init__() + # assign the attributes + self.num_edge = num_edge + self.num_ops = num_ops + self.lstm_size = lstm_size + self.lstm_N = lstm_num_layers + self.tanh_constant = tanh_constant + self.temperature = temperature + # create parameters + self.register_parameter('input_vars', nn.Parameter(torch.Tensor(1, 1, lstm_size))) + self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N) + self.w_embd = nn.Embedding(self.num_ops, self.lstm_size) + self.w_pred = nn.Linear(self.lstm_size, self.num_ops) + + nn.init.uniform_(self.input_vars , -0.1, 0.1) + nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1) + nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1) + nn.init.uniform_(self.w_embd.weight , -0.1, 0.1) + nn.init.uniform_(self.w_pred.weight , -0.1, 0.1) + + def forward(self): + + inputs, h0 = self.input_vars, None + log_probs, entropys, sampled_arch = [], [], [] + for iedge in range(self.num_edge): + outputs, h0 = self.w_lstm(inputs, h0) + + logits = self.w_pred(outputs) + logits = logits / self.temperature + logits = self.tanh_constant * torch.tanh(logits) + # distribution + op_distribution = Categorical(logits=logits) + op_index = op_distribution.sample() + sampled_arch.append( op_index.item() ) + + op_log_prob = op_distribution.log_prob(op_index) + log_probs.append( op_log_prob.view(-1) ) + op_entropy = op_distribution.entropy() + entropys.append( op_entropy.view(-1) ) + + # obtain the input embedding for the next step + inputs = self.w_embd(op_index) + return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), sampled_arch diff --git a/lib/models/l2s_cell_searchs/search_model_gdas.py b/lib/models/l2s_cell_searchs/search_model_gdas.py new file mode 100644 index 0000000..6a4dd4e --- /dev/null +++ b/lib/models/l2s_cell_searchs/search_model_gdas.py @@ -0,0 +1,96 @@ +########################################################################### +# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # +########################################################################### +import torch +import torch.nn as nn +from copy import deepcopy +from ..cell_operations import ResNetBasicblock +from .search_cells import SearchCell +from .genotypes import Structure + + +class TinyNetworkGDAS(nn.Module): + + def __init__(self, C, N, max_nodes, num_classes, search_space): + super(TinyNetworkGDAS, self).__init__() + self._C = C + self._layerN = N + self.max_nodes = max_nodes + self.stem = nn.Sequential( + nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(C)) + + layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + + C_prev, num_edge, edge2index = C, None, None + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): + if reduction: + cell = ResNetBasicblock(C_prev, C_curr, 2) + else: + cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space) + if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index + else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) + self.cells.append( cell ) + C_prev = cell.out_dim + self.op_names = deepcopy( search_space ) + self._Layer = len(self.cells) + self.edge2index = edge2index + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) + self.tau = 10 + + def get_weights(self): + xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) + xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) + xlist+= list( self.classifier.parameters() ) + return xlist + + def set_tau(self, tau): + self.tau = tau + + def get_tau(self): + return self.tau + + def get_alphas(self): + return [self.arch_parameters] + + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) + return string + + def extra_repr(self): + return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) + + def genotype(self): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + with torch.no_grad(): + weights = self.arch_parameters[ self.edge2index[node_str] ] + op_name = self.op_names[ weights.argmax().item() ] + xlist.append((op_name, j)) + genotypes.append( tuple(xlist) ) + return Structure( genotypes ) + + def forward(self, inputs): + feature = self.stem(inputs) + for i, cell in enumerate(self.cells): + if isinstance(cell, SearchCell): + feature = cell.forward_gdas(feature, self.arch_parameters, self.tau) + else: + feature = cell(feature) + + out = self.lastact(feature) + out = self.global_pooling( out ) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + + return out, logits diff --git a/lib/models/l2s_cell_searchs/search_model_random.py b/lib/models/l2s_cell_searchs/search_model_random.py new file mode 100644 index 0000000..c2f83f9 --- /dev/null +++ b/lib/models/l2s_cell_searchs/search_model_random.py @@ -0,0 +1,81 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +############################################################################## +# Random Search and Reproducibility for Neural Architecture Search, UAI 2019 # +############################################################################## +import torch, random +import torch.nn as nn +from copy import deepcopy +from ..cell_operations import ResNetBasicblock +from .search_cells import SearchCell +from .genotypes import Structure + + +class TinyNetworkRANDOM(nn.Module): + + def __init__(self, C, N, max_nodes, num_classes, search_space): + super(TinyNetworkRANDOM, self).__init__() + self._C = C + self._layerN = N + self.max_nodes = max_nodes + self.stem = nn.Sequential( + nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(C)) + + layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + + C_prev, num_edge, edge2index = C, None, None + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): + if reduction: + cell = ResNetBasicblock(C_prev, C_curr, 2) + else: + cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space) + if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index + else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) + self.cells.append( cell ) + C_prev = cell.out_dim + self.op_names = deepcopy( search_space ) + self._Layer = len(self.cells) + self.edge2index = edge2index + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + self.arch_cache = None + + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) + return string + + def extra_repr(self): + return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) + + def random_genotype(self, set_cache): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + op_name = random.choice( self.op_names ) + xlist.append((op_name, j)) + genotypes.append( tuple(xlist) ) + arch = Structure( genotypes ) + if set_cache: self.arch_cache = arch + return arch + + def forward(self, inputs): + + feature = self.stem(inputs) + for i, cell in enumerate(self.cells): + if isinstance(cell, SearchCell): + feature = cell.forward_dynamic(feature, self.arch_cache) + else: feature = cell(feature) + + out = self.lastact(feature) + out = self.global_pooling( out ) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + return out, logits diff --git a/lib/models/l2s_cell_searchs/search_model_setn.py b/lib/models/l2s_cell_searchs/search_model_setn.py new file mode 100644 index 0000000..5864f32 --- /dev/null +++ b/lib/models/l2s_cell_searchs/search_model_setn.py @@ -0,0 +1,152 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +###################################################################################### +# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 # +###################################################################################### +import torch, random +import torch.nn as nn +from copy import deepcopy +from ..cell_operations import ResNetBasicblock +from .search_cells import SearchCell +from .genotypes import Structure + + +class TinyNetworkSETN(nn.Module): + + def __init__(self, C, N, max_nodes, num_classes, search_space): + super(TinyNetworkSETN, self).__init__() + self._C = C + self._layerN = N + self.max_nodes = max_nodes + self.stem = nn.Sequential( + nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(C)) + + layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + + C_prev, num_edge, edge2index = C, None, None + self.cells = nn.ModuleList() + for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): + if reduction: + cell = ResNetBasicblock(C_prev, C_curr, 2) + else: + cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space) + if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index + else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) + self.cells.append( cell ) + C_prev = cell.out_dim + self.op_names = deepcopy( search_space ) + self._Layer = len(self.cells) + self.edge2index = edge2index + self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) + self.global_pooling = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(C_prev, num_classes) + self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) + self.mode = 'urs' + self.dynamic_cell = None + + def set_cal_mode(self, mode, dynamic_cell=None): + assert mode in ['urs', 'joint', 'select', 'dynamic'] + self.mode = mode + if mode == 'dynamic': self.dynamic_cell = deepcopy( dynamic_cell ) + else : self.dynamic_cell = None + + def get_cal_mode(self): + return self.mode + + def get_weights(self): + xlist = list( self.stem.parameters() ) + list( self.cells.parameters() ) + xlist+= list( self.lastact.parameters() ) + list( self.global_pooling.parameters() ) + xlist+= list( self.classifier.parameters() ) + return xlist + + def get_alphas(self): + return [self.arch_parameters] + + def get_message(self): + string = self.extra_repr() + for i, cell in enumerate(self.cells): + string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr()) + return string + + def extra_repr(self): + return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__)) + + def genotype(self): + genotypes = [] + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + with torch.no_grad(): + weights = self.arch_parameters[ self.edge2index[node_str] ] + op_name = self.op_names[ weights.argmax().item() ] + xlist.append((op_name, j)) + genotypes.append( tuple(xlist) ) + return Structure( genotypes ) + + def dync_genotype(self, use_random=False): + genotypes = [] + with torch.no_grad(): + alphas_cpu = nn.functional.softmax(self.arch_parameters, dim=-1) + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + if use_random: + op_name = random.choice(self.op_names) + else: + weights = alphas_cpu[ self.edge2index[node_str] ] + op_index = torch.multinomial(weights, 1).item() + op_name = self.op_names[ op_index ] + xlist.append((op_name, j)) + genotypes.append( tuple(xlist) ) + return Structure( genotypes ) + + def get_log_prob(self, arch): + with torch.no_grad(): + logits = nn.functional.log_softmax(self.arch_parameters, dim=-1) + select_logits = [] + for i, node_info in enumerate(arch.nodes): + for op, xin in node_info: + node_str = '{:}<-{:}'.format(i+1, xin) + op_index = self.op_names.index(op) + select_logits.append( logits[self.edge2index[node_str], op_index] ) + return sum(select_logits).item() + + + def return_topK(self, K): + archs = Structure.gen_all(self.op_names, self.max_nodes, False) + pairs = [(self.get_log_prob(arch), arch) for arch in archs] + if K < 0 or K >= len(archs): K = len(archs) + sorted_pairs = sorted(pairs, key=lambda x: -x[0]) + return_pairs = [sorted_pairs[_][1] for _ in range(K)] + return return_pairs + + + def forward(self, inputs): + alphas = nn.functional.softmax(self.arch_parameters, dim=-1) + with torch.no_grad(): + alphas_cpu = alphas.detach().cpu() + + feature = self.stem(inputs) + for i, cell in enumerate(self.cells): + if isinstance(cell, SearchCell): + if self.mode == 'urs': + feature = cell.forward_urs(feature) + elif self.mode == 'select': + feature = cell.forward_select(feature, alphas_cpu) + elif self.mode == 'joint': + feature = cell.forward_joint(feature, alphas) + elif self.mode == 'dynamic': + feature = cell.forward_dynamic(feature, self.dynamic_cell) + else: raise ValueError('invalid mode={:}'.format(self.mode)) + else: feature = cell(feature) + + out = self.lastact(feature) + out = self.global_pooling( out ) + out = out.view(out.size(0), -1) + logits = self.classifier(out) + + return out, logits