import numpy as np import copy import itertools import random import sys import os import pickle import torch from torch.nn.functional import one_hot # from ofa.imagenet_classification.run_manager import RunManager # from naszilla.nas_bench_201.distances import * # INPUT = 'input' # OUTPUT = 'output' # OPS = ['avg_pool_3x3', 'nor_conv_1x1', 'nor_conv_3x3', 'none', 'skip_connect'] # NUM_OPS = len(OPS) # OP_SPOTS = 20 KS_LIST = [3, 5, 7] EXPAND_LIST = [3, 4, 6] DEPTH_LIST = [2, 3, 4] NUM_STAGE = 5 MAX_LAYER_PER_STAGE = 4 MAX_N_BLOCK= NUM_STAGE * MAX_LAYER_PER_STAGE # 20 OPS = { '3-3': 0, '3-4': 1, '3-6': 2, '5-3': 3, '5-4': 4, '5-6': 5, '7-3': 6, '7-4': 7, '7-6': 8, } OPS2STR = { 0: '3-3', 1: '3-4', 2: '3-6', 3: '5-3', 4: '5-4', 5: '5-6', 6: '7-3', 7: '7-4', 8: '7-6', } NUM_OPS = len(OPS) LONGEST_PATH_LENGTH = 20 # OP_SPOTS = NUM_VERTICES - 2 # OPS = { # '3-3': 0, '3-4': 1, '3-6': 2, # '5-3': 3, '5-4': 4, '5-6': 5, # '7-3': 6, '7-4': 7, '7-6': 8, # } # OFA evolution hyper-parameters # self.arch_mutate_prob = kwargs.get("arch_mutate_prob", 0.1) # self.resolution_mutate_prob = kwargs.get("resolution_mutate_prob", 0.5) # self.population_size = kwargs.get("population_size", 100) # self.max_time_budget = kwargs.get("max_time_budget", 500) # self.parent_ratio = kwargs.get("parent_ratio", 0.25) # self.mutation_ratio = kwargs.get("mutation_ratio", 0.5) class OFASubNet: def __init__(self, string, accuracy_predictor=None): self.string = string self.accuracy_predictor = accuracy_predictor # self.run_config = run_config def get_string(self): return self.string def serialize(self): return { 'string':self.string } @classmethod def random_cell(cls, nasbench, max_nodes=None, max_edges=None, cutoff=None, index_hash=None, random_encoding=None): """ OFA sample random subnet """ # Randomly sample sub-networks from OFA network random_subnet_config = nasbench.sample_active_subnet() #{ # "ks": ks_setting, # "e": expand_setting, # "d": depth_setting, # } return {'string':cls.get_string_from_ops(random_subnet_config)} def encode(self, predictor_encoding, nasbench=None, deterministic=True, cutoff=None, nasbench_ours=None, dataset=None): if predictor_encoding == 'adj': return self.encode_standard() elif predictor_encoding == 'path': raise NotImplementedError return self.encode_paths() elif predictor_encoding == 'trunc_path': if not cutoff: cutoff = 30 dic = self.gcn_encoding(nasbench, deterministic=deterministic, nasbench_ours=nasbench_ours, dataset=dataset) dic['trunc_path'] = self.encode_freq_paths(cutoff=cutoff) return dic # return self.encode_freq_paths(cutoff=cutoff) elif predictor_encoding == 'gcn': return self.gcn_encoding(nasbench, deterministic=deterministic, nasbench_ours=nasbench_ours, dataset=dataset) else: print('{} is an invalid predictor encoding'.format(predictor_encoding)) raise NotImplementedError() def get_ops_onehot(self): ops = self.get_op_dict() # ops = [INPUT, *ops, OUTPUT] node_types = torch.zeros(NUM_STAGE * MAX_LAYER_PER_STAGE).long() # w/o in / out num_vertices = len(OPS.values()) num_nodes = NUM_STAGE * MAX_LAYER_PER_STAGE d_matrix = [] # import pdb; pdb.set_trace() for i in range(NUM_STAGE): ds = ops['d'][i] for j in range(ds): d_matrix.append(ds) for j in range(MAX_LAYER_PER_STAGE - ds): d_matrix.append('none') for i, (ks, e, d) in enumerate(zip( ops['ks'], ops['e'], d_matrix)): if d == 'none': # node_types[i] = OPS[d] pass else: node_types[i] = OPS[f'{ks}-{e}'] ops_onehot = one_hot(node_types, num_vertices).float() return ops_onehot def gcn_encoding(self, nasbench, deterministic, nasbench_ours=None, dataset=None): # op_map = [OUTPUT, INPUT, *OPS] ops = self.get_op_dict() # ops = [INPUT, *ops, OUTPUT] node_types = torch.zeros(NUM_STAGE * MAX_LAYER_PER_STAGE).long() # w/o in / out num_vertices = len(OPS.values()) num_nodes = NUM_STAGE * MAX_LAYER_PER_STAGE d_matrix = [] # import pdb; pdb.set_trace() for i in range(NUM_STAGE): ds = ops['d'][i] for j in range(ds): d_matrix.append(ds) for j in range(MAX_LAYER_PER_STAGE - ds): d_matrix.append('none') for i, (ks, e, d) in enumerate(zip( ops['ks'], ops['e'], d_matrix)): if d == 'none': # node_types[i] = OPS[d] pass else: node_types[i] = OPS[f'{ks}-{e}'] ops_onehot = one_hot(node_types, num_vertices).float() val_loss = self.get_val_loss(nasbench, dataset=dataset) test_loss = copy.deepcopy(val_loss) # (num node, ops types) --> (20, 28) def get_adj(): adj = torch.zeros(num_nodes, num_nodes) for i in range(num_nodes-1): adj[i, i+1] = 1 adj = np.array(adj) return adj matrix = get_adj() dic = { 'num_vertices': num_vertices, 'adjacency': matrix, 'operations': ops_onehot, 'mask': np.array([i < num_vertices for i in range(num_vertices)], dtype=np.float32), 'val_acc': 1.0 - val_loss, 'test_acc': 1.0 - test_loss, 'x': ops_onehot } return dic def get_runtime(self, nasbench, dataset='cifar100'): return nasbench.query_by_index(index, dataset).get_eval('x-valid')['time'] def get_val_loss(self, nasbench, deterministic=1, dataset='cifar100'): assert dataset == 'imagenet1k' # SuperNet version # ops = self.get_op_dict() # nasbench.set_active_subnet(ks=ops['ks'], e=ops['e'], d=ops['d']) # subnet = nasbench.get_active_subnet(preserve_weight=True) # run_manager = RunManager(".tmp/eval_subnet", subnet, self.run_config, init=False) # # assign image size: 128, 132, ..., 224 # self.run_config.data_provider.assign_active_img_size(224) # run_manager.reset_running_statistics(net=subnet) # loss, (top1, top5) = run_manager.validate(net=subnet) # # print("Results: loss=%.5f,\t top1=%.1f,\t top5=%.1f" % (loss, top1, top5)) # self.loss = loss # self.top1 = top1 # self.top5 = top5 # accuracy predictor version ops = self.get_op_dict() # resolutions = [160, 176, 192, 208, 224] # ops['r'] = [random.choice(resolutions)] ops['r'] = [224] acc = self.accuracy_predictor.predict_accuracy([ops])[0][0].item() return 1.0 - acc def get_test_loss(self, nasbench, dataset='cifar100', deterministic=1): ops = self.get_op_dict() # resolutions = [160, 176, 192, 208, 224] # ops['r'] = [random.choice(resolutions)] ops['r'] = [224] acc = self.accuracy_predictor.predict_accuracy([ops])[0][0].item() return 1.0 - acc def get_op_dict(self): # given a string, get the list of operations ops = { "ks": [], "e": [], "d": [] } tokens = self.string.split('_') for i, token in enumerate(tokens): d, ks, e = token.split('-') if i % MAX_LAYER_PER_STAGE == 0: ops['d'].append(int(d)) ops['ks'].append(int(ks)) ops['e'].append(int(e)) return ops def get_num(self): # compute the unique number of the architecture, in [0, 15624] ops = self.get_op_dict() index = 0 for i, op in enumerate(ops): index += OPS.index(op) * NUM_OPS ** i return index def get_random_hash(self): num = self.get_num() hashes = pickle.load(open('nas_bench_201/random_hash.pkl', 'rb')) return hashes[num] @classmethod def get_string_from_ops(cls, ops): string = '' for i, (ks, e) in enumerate(zip(ops['ks'], ops['e'])): d = ops['d'][int(i/MAX_LAYER_PER_STAGE)] string += f'{d}-{ks}-{e}_' return string[:-1] def perturb(self, nasbench, mutation_rate=1): # deterministic version of mutate ops = self.get_op_dict() new_ops = [] num = np.random.choice(len(ops)) for i, op in enumerate(ops): if i == num: available = [o for o in OPS if o != op] new_ops.append(np.random.choice(available)) else: new_ops.append(op) return {'string':self.get_string_from_ops(new_ops)} def mutate(self, nasbench, mutation_rate=0.1, mutate_encoding='adj', index_hash=None, cutoff=30, patience=5000): p = 0 mutation_rate = mutation_rate / 10 # OFA rate: 0.1 arch_dict = self.get_op_dict() if mutate_encoding == 'adj': # OFA version mutation # https://github.com/mit-han-lab/once-for-all/blob/master/ofa/nas/search_algorithm/evolution.py for i in range(MAX_N_BLOCK): if random.random() < mutation_rate: available_ks = [ks for ks in KS_LIST if ks != arch_dict["ks"][i]] available_e = [e for e in EXPAND_LIST if e != arch_dict["e"][i]] arch_dict["ks"][i] = random.choice(available_ks) arch_dict["e"][i] = random.choice(available_e) for i in range(NUM_STAGE): if random.random() < mutation_rate: available_d = [d for d in DEPTH_LIST if d != arch_dict["d"][i]] arch_dict["d"][i] = random.choice(available_d) return {'string':self.get_string_from_ops(arch_dict)} elif mutate_encoding in ['path', 'trunc_path']: raise NotImplementedError() else: print('{} is an invalid mutate encoding'.format(mutate_encoding)) raise NotImplementedError() def encode_standard(self): """ compute the standard encoding """ ops = self.get_op_dict() encoding = [] for i, (ks, e) in enumerate(zip(ops['ks'], ops['e'])): string = f'{ks}-{e}' encoding.append(OPS[string]) return encoding def encode_one_hot(self): """ compute the one-hot encoding """ encoding = self.encode_standard() one_hot = [] for num in encoding: for i in range(len(OPS)): if i == num: one_hot.append(1) else: one_hot.append(0) return one_hot def get_num_params(self, nasbench): # todo: add this method return 100 def get_paths(self): """ return all paths from input to output """ path_blueprints = [[3], [0,4], [1,5], [0,2,5]] ops = self.get_op_dict() paths = [] for blueprint in path_blueprints: paths.append([ops[node] for node in blueprint]) return paths def get_path_indices(self): """ compute the index of each path """ paths = self.get_paths() path_indices = [] for i, path in enumerate(paths): if i == 0: index = 0 elif i in [1, 2]: index = NUM_OPS else: index = NUM_OPS + NUM_OPS ** 2 import pdb; pdb.set_trace() for j, op in enumerate(path): index += OPS.index(op) * NUM_OPS ** j path_indices.append(index) return tuple(path_indices) def encode_paths(self): """ output one-hot encoding of paths """ num_paths = sum([NUM_OPS ** i for i in range(1, LONGEST_PATH_LENGTH + 1)]) path_indices = self.get_path_indices() encoding = np.zeros(num_paths) for index in path_indices: encoding[index] = 1 return encoding def encode_freq_paths(self, cutoff=30): # natural cutoffs 5, 30, 155 (last) num_paths = sum([NUM_OPS ** i for i in range(1, LONGEST_PATH_LENGTH + 1)]) path_indices = self.get_path_indices() encoding = np.zeros(cutoff) for index in range(min(num_paths, cutoff)): if index in path_indices: encoding[index] = 1 return encoding def distance(self, other, dist_type, cutoff=30): if dist_type == 'adj': distance = adj_distance(self, other) elif dist_type == 'path': distance = path_distance(self, other) elif dist_type == 'trunc_path': distance = path_distance(self, other, cutoff=cutoff) elif dist_type == 'nasbot': distance = nasbot_distance(self, other) else: print('{} is an invalid distance'.format(distance)) raise NotImplementedError() return distance def get_neighborhood(self, nasbench, mutate_encoding, shuffle=True): nbhd = [] ops = self.get_op_dict() if mutate_encoding == 'adj': # OFA version mutation variation # https://github.com/mit-han-lab/once-for-all/blob/master/ofa/nas/search_algorithm/evolution.py for i in range(MAX_N_BLOCK): available_ks = [ks for ks in KS_LIST if ks != ops["ks"][i]] for ks in available_ks: new_ops = ops.copy() new_ops["ks"][i] = ks new_arch = {'string':self.get_string_from_ops(new_ops)} nbhd.append(new_arch) available_e = [e for e in EXPAND_LIST if e != ops["e"][i]] for e in available_e: new_ops = ops.copy() new_ops["e"][i] = e new_arch = {'string':self.get_string_from_ops(new_ops)} nbhd.append(new_arch) # for i in range(MAX_N_BLOCK): # available_ks = [ks for ks in KS_LIST if ks != ops["ks"][i]] # available_e = [e for e in EXPAND_LIST if e != ops["e"][i]] # for ks, e in zip(available_ks, available_e): # new_ops = ops.copy() # new_ops["ks"][i] = ks # new_ops["e"][i] = e # new_arch = {'string':self.get_string_from_ops(new_ops)} # nbhd.append(new_arch) for i in range(NUM_STAGE): available_d = [d for d in DEPTH_LIST if d != ops["d"][i]] for d in available_d: new_ops = ops.copy() new_ops["d"][i] = d new_arch = {'string':self.get_string_from_ops(new_ops)} nbhd.append(new_arch) # if mutate_encoding == 'adj': # for i in range(len(ops)): # import pdb; pdb.set_trace() # available = [op for op in OPS.keys() if op != ops[i]] # for op in available: # new_ops = ops.copy() # new_ops[i] = op # new_arch = {'string':self.get_string_from_ops(new_ops)} # nbhd.append(new_arch) elif mutate_encoding in ['path', 'trunc_path']: if mutate_encoding == 'trunc_path': path_blueprints = [[3], [0,4], [1,5]] else: path_blueprints = [[3], [0,4], [1,5], [0,2,5]] ops = self.get_op_dict() for blueprint in path_blueprints: for new_path in itertools.product(OPS, repeat=len(blueprint)): new_ops = ops.copy() for i, op in enumerate(new_path): new_ops[blueprint[i]] = op # check if it's the same same = True for j in range(len(ops)): if ops[j] != new_ops[j]: same = False if not same: new_arch = {'string':self.get_string_from_ops(new_ops)} nbhd.append(new_arch) else: print('{} is an invalid mutate encoding'.format(mutate_encoding)) raise NotImplementedError() if shuffle: random.shuffle(nbhd) return nbhd def get_unique_string(self): ops = self.get_op_dict() d_matrix = [] for i in range(NUM_STAGE): ds = ops['d'][i] for j in range(ds): d_matrix.append(ds) for j in range(MAX_LAYER_PER_STAGE - ds): d_matrix.append('none') string = '' for i, (ks, e, d) in enumerate(zip(ops['ks'], ops['e'], d_matrix)): if d == 'none': string += f'0-0-0_' else: string += f'{d}-{ks}-{e}_' return string[:-1]