diffusionNAG/MobileNetV3/main_exp/transfer_nag_lib/ofa_net.py
2024-03-15 14:38:51 +00:00

521 lines
18 KiB
Python

import numpy as np
import copy
import itertools
import random
import sys
import os
import pickle
import torch
from torch.nn.functional import one_hot
# from ofa.imagenet_classification.run_manager import RunManager
# from naszilla.nas_bench_201.distances import *
# INPUT = 'input'
# OUTPUT = 'output'
# OPS = ['avg_pool_3x3', 'nor_conv_1x1', 'nor_conv_3x3', 'none', 'skip_connect']
# NUM_OPS = len(OPS)
# OP_SPOTS = 20
KS_LIST = [3, 5, 7]
EXPAND_LIST = [3, 4, 6]
DEPTH_LIST = [2, 3, 4]
NUM_STAGE = 5
MAX_LAYER_PER_STAGE = 4
MAX_N_BLOCK= NUM_STAGE * MAX_LAYER_PER_STAGE # 20
OPS = {
'3-3': 0, '3-4': 1, '3-6': 2,
'5-3': 3, '5-4': 4, '5-6': 5,
'7-3': 6, '7-4': 7, '7-6': 8,
}
OPS2STR = {
0: '3-3', 1: '3-4', 2: '3-6',
3: '5-3', 4: '5-4', 5: '5-6',
6: '7-3', 7: '7-4', 8: '7-6',
}
NUM_OPS = len(OPS)
LONGEST_PATH_LENGTH = 20
# OP_SPOTS = NUM_VERTICES - 2
# OPS = {
# '3-3': 0, '3-4': 1, '3-6': 2,
# '5-3': 3, '5-4': 4, '5-6': 5,
# '7-3': 6, '7-4': 7, '7-6': 8,
# }
# OFA evolution hyper-parameters
# self.arch_mutate_prob = kwargs.get("arch_mutate_prob", 0.1)
# self.resolution_mutate_prob = kwargs.get("resolution_mutate_prob", 0.5)
# self.population_size = kwargs.get("population_size", 100)
# self.max_time_budget = kwargs.get("max_time_budget", 500)
# self.parent_ratio = kwargs.get("parent_ratio", 0.25)
# self.mutation_ratio = kwargs.get("mutation_ratio", 0.5)
class OFASubNet:
def __init__(self, string, accuracy_predictor=None):
self.string = string
self.accuracy_predictor = accuracy_predictor
# self.run_config = run_config
def get_string(self):
return self.string
def serialize(self):
return {
'string':self.string
}
@classmethod
def random_cell(cls,
nasbench,
max_nodes=None,
max_edges=None,
cutoff=None,
index_hash=None,
random_encoding=None):
"""
OFA sample random subnet
"""
# Randomly sample sub-networks from OFA network
random_subnet_config = nasbench.sample_active_subnet()
#{
# "ks": ks_setting,
# "e": expand_setting,
# "d": depth_setting,
# }
return {'string':cls.get_string_from_ops(random_subnet_config)}
def encode(self,
predictor_encoding,
nasbench=None,
deterministic=True,
cutoff=None,
nasbench_ours=None,
dataset=None):
if predictor_encoding == 'adj':
return self.encode_standard()
elif predictor_encoding == 'path':
raise NotImplementedError
return self.encode_paths()
elif predictor_encoding == 'trunc_path':
if not cutoff:
cutoff = 30
dic = self.gcn_encoding(nasbench,
deterministic=deterministic,
nasbench_ours=nasbench_ours,
dataset=dataset)
dic['trunc_path'] = self.encode_freq_paths(cutoff=cutoff)
return dic
# return self.encode_freq_paths(cutoff=cutoff)
elif predictor_encoding == 'gcn':
return self.gcn_encoding(nasbench,
deterministic=deterministic,
nasbench_ours=nasbench_ours,
dataset=dataset)
else:
print('{} is an invalid predictor encoding'.format(predictor_encoding))
raise NotImplementedError()
def get_ops_onehot(self):
ops = self.get_op_dict()
# ops = [INPUT, *ops, OUTPUT]
node_types = torch.zeros(NUM_STAGE * MAX_LAYER_PER_STAGE).long() # w/o in / out
num_vertices = len(OPS.values())
num_nodes = NUM_STAGE * MAX_LAYER_PER_STAGE
d_matrix = []
# import pdb; pdb.set_trace()
for i in range(NUM_STAGE):
ds = ops['d'][i]
for j in range(ds):
d_matrix.append(ds)
for j in range(MAX_LAYER_PER_STAGE - ds):
d_matrix.append('none')
for i, (ks, e, d) in enumerate(zip(
ops['ks'], ops['e'], d_matrix)):
if d == 'none':
# node_types[i] = OPS[d]
pass
else:
node_types[i] = OPS[f'{ks}-{e}']
ops_onehot = one_hot(node_types, num_vertices).float()
return ops_onehot
def gcn_encoding(self, nasbench, deterministic, nasbench_ours=None, dataset=None):
# op_map = [OUTPUT, INPUT, *OPS]
ops = self.get_op_dict()
# ops = [INPUT, *ops, OUTPUT]
node_types = torch.zeros(NUM_STAGE * MAX_LAYER_PER_STAGE).long() # w/o in / out
num_vertices = len(OPS.values())
num_nodes = NUM_STAGE * MAX_LAYER_PER_STAGE
d_matrix = []
# import pdb; pdb.set_trace()
for i in range(NUM_STAGE):
ds = ops['d'][i]
for j in range(ds):
d_matrix.append(ds)
for j in range(MAX_LAYER_PER_STAGE - ds):
d_matrix.append('none')
for i, (ks, e, d) in enumerate(zip(
ops['ks'], ops['e'], d_matrix)):
if d == 'none':
# node_types[i] = OPS[d]
pass
else:
node_types[i] = OPS[f'{ks}-{e}']
ops_onehot = one_hot(node_types, num_vertices).float()
val_loss = self.get_val_loss(nasbench, dataset=dataset)
test_loss = copy.deepcopy(val_loss)
# (num node, ops types) --> (20, 28)
def get_adj():
adj = torch.zeros(num_nodes, num_nodes)
for i in range(num_nodes-1):
adj[i, i+1] = 1
adj = np.array(adj)
return adj
matrix = get_adj()
dic = {
'num_vertices': num_vertices,
'adjacency': matrix,
'operations': ops_onehot,
'mask': np.array([i < num_vertices for i in range(num_vertices)], dtype=np.float32),
'val_acc': 1.0 - val_loss,
'test_acc': 1.0 - test_loss,
'x': ops_onehot
}
return dic
def get_runtime(self, nasbench, dataset='cifar100'):
return nasbench.query_by_index(index, dataset).get_eval('x-valid')['time']
def get_val_loss(self, nasbench, deterministic=1, dataset='cifar100'):
assert dataset == 'imagenet1k'
# SuperNet version
# ops = self.get_op_dict()
# nasbench.set_active_subnet(ks=ops['ks'], e=ops['e'], d=ops['d'])
# subnet = nasbench.get_active_subnet(preserve_weight=True)
# run_manager = RunManager(".tmp/eval_subnet", subnet, self.run_config, init=False)
# # assign image size: 128, 132, ..., 224
# self.run_config.data_provider.assign_active_img_size(224)
# run_manager.reset_running_statistics(net=subnet)
# loss, (top1, top5) = run_manager.validate(net=subnet)
# # print("Results: loss=%.5f,\t top1=%.1f,\t top5=%.1f" % (loss, top1, top5))
# self.loss = loss
# self.top1 = top1
# self.top5 = top5
# accuracy predictor version
ops = self.get_op_dict()
# resolutions = [160, 176, 192, 208, 224]
# ops['r'] = [random.choice(resolutions)]
ops['r'] = [224]
acc = self.accuracy_predictor.predict_accuracy([ops])[0][0].item()
return 1.0 - acc
def get_test_loss(self, nasbench, dataset='cifar100', deterministic=1):
ops = self.get_op_dict()
# resolutions = [160, 176, 192, 208, 224]
# ops['r'] = [random.choice(resolutions)]
ops['r'] = [224]
acc = self.accuracy_predictor.predict_accuracy([ops])[0][0].item()
return 1.0 - acc
def get_op_dict(self):
# given a string, get the list of operations
ops = {
"ks": [], "e": [], "d": []
}
tokens = self.string.split('_')
for i, token in enumerate(tokens):
d, ks, e = token.split('-')
if i % MAX_LAYER_PER_STAGE == 0:
ops['d'].append(int(d))
ops['ks'].append(int(ks))
ops['e'].append(int(e))
return ops
def get_num(self):
# compute the unique number of the architecture, in [0, 15624]
ops = self.get_op_dict()
index = 0
for i, op in enumerate(ops):
index += OPS.index(op) * NUM_OPS ** i
return index
def get_random_hash(self):
num = self.get_num()
hashes = pickle.load(open('nas_bench_201/random_hash.pkl', 'rb'))
return hashes[num]
@classmethod
def get_string_from_ops(cls, ops):
string = ''
for i, (ks, e) in enumerate(zip(ops['ks'], ops['e'])):
d = ops['d'][int(i/MAX_LAYER_PER_STAGE)]
string += f'{d}-{ks}-{e}_'
return string[:-1]
def perturb(self,
nasbench,
mutation_rate=1):
# deterministic version of mutate
ops = self.get_op_dict()
new_ops = []
num = np.random.choice(len(ops))
for i, op in enumerate(ops):
if i == num:
available = [o for o in OPS if o != op]
new_ops.append(np.random.choice(available))
else:
new_ops.append(op)
return {'string':self.get_string_from_ops(new_ops)}
def mutate(self,
nasbench,
mutation_rate=0.1,
mutate_encoding='adj',
index_hash=None,
cutoff=30,
patience=5000):
p = 0
mutation_rate = mutation_rate / 10 # OFA rate: 0.1
arch_dict = self.get_op_dict()
if mutate_encoding == 'adj':
# OFA version mutation
# https://github.com/mit-han-lab/once-for-all/blob/master/ofa/nas/search_algorithm/evolution.py
for i in range(MAX_N_BLOCK):
if random.random() < mutation_rate:
available_ks = [ks for ks in KS_LIST if ks != arch_dict["ks"][i]]
available_e = [e for e in EXPAND_LIST if e != arch_dict["e"][i]]
arch_dict["ks"][i] = random.choice(available_ks)
arch_dict["e"][i] = random.choice(available_e)
for i in range(NUM_STAGE):
if random.random() < mutation_rate:
available_d = [d for d in DEPTH_LIST if d != arch_dict["d"][i]]
arch_dict["d"][i] = random.choice(available_d)
return {'string':self.get_string_from_ops(arch_dict)}
elif mutate_encoding in ['path', 'trunc_path']:
raise NotImplementedError()
else:
print('{} is an invalid mutate encoding'.format(mutate_encoding))
raise NotImplementedError()
def encode_standard(self):
"""
compute the standard encoding
"""
ops = self.get_op_dict()
encoding = []
for i, (ks, e) in enumerate(zip(ops['ks'], ops['e'])):
string = f'{ks}-{e}'
encoding.append(OPS[string])
return encoding
def encode_one_hot(self):
"""
compute the one-hot encoding
"""
encoding = self.encode_standard()
one_hot = []
for num in encoding:
for i in range(len(OPS)):
if i == num:
one_hot.append(1)
else:
one_hot.append(0)
return one_hot
def get_num_params(self, nasbench):
# todo: add this method
return 100
def get_paths(self):
"""
return all paths from input to output
"""
path_blueprints = [[3], [0,4], [1,5], [0,2,5]]
ops = self.get_op_dict()
paths = []
for blueprint in path_blueprints:
paths.append([ops[node] for node in blueprint])
return paths
def get_path_indices(self):
"""
compute the index of each path
"""
paths = self.get_paths()
path_indices = []
for i, path in enumerate(paths):
if i == 0:
index = 0
elif i in [1, 2]:
index = NUM_OPS
else:
index = NUM_OPS + NUM_OPS ** 2
import pdb; pdb.set_trace()
for j, op in enumerate(path):
index += OPS.index(op) * NUM_OPS ** j
path_indices.append(index)
return tuple(path_indices)
def encode_paths(self):
""" output one-hot encoding of paths """
num_paths = sum([NUM_OPS ** i for i in range(1, LONGEST_PATH_LENGTH + 1)])
path_indices = self.get_path_indices()
encoding = np.zeros(num_paths)
for index in path_indices:
encoding[index] = 1
return encoding
def encode_freq_paths(self, cutoff=30):
# natural cutoffs 5, 30, 155 (last)
num_paths = sum([NUM_OPS ** i for i in range(1, LONGEST_PATH_LENGTH + 1)])
path_indices = self.get_path_indices()
encoding = np.zeros(cutoff)
for index in range(min(num_paths, cutoff)):
if index in path_indices:
encoding[index] = 1
return encoding
def distance(self, other, dist_type, cutoff=30):
if dist_type == 'adj':
distance = adj_distance(self, other)
elif dist_type == 'path':
distance = path_distance(self, other)
elif dist_type == 'trunc_path':
distance = path_distance(self, other, cutoff=cutoff)
elif dist_type == 'nasbot':
distance = nasbot_distance(self, other)
else:
print('{} is an invalid distance'.format(distance))
raise NotImplementedError()
return distance
def get_neighborhood(self,
nasbench,
mutate_encoding,
shuffle=True):
nbhd = []
ops = self.get_op_dict()
if mutate_encoding == 'adj':
# OFA version mutation variation
# https://github.com/mit-han-lab/once-for-all/blob/master/ofa/nas/search_algorithm/evolution.py
for i in range(MAX_N_BLOCK):
available_ks = [ks for ks in KS_LIST if ks != ops["ks"][i]]
for ks in available_ks:
new_ops = ops.copy()
new_ops["ks"][i] = ks
new_arch = {'string':self.get_string_from_ops(new_ops)}
nbhd.append(new_arch)
available_e = [e for e in EXPAND_LIST if e != ops["e"][i]]
for e in available_e:
new_ops = ops.copy()
new_ops["e"][i] = e
new_arch = {'string':self.get_string_from_ops(new_ops)}
nbhd.append(new_arch)
# for i in range(MAX_N_BLOCK):
# available_ks = [ks for ks in KS_LIST if ks != ops["ks"][i]]
# available_e = [e for e in EXPAND_LIST if e != ops["e"][i]]
# for ks, e in zip(available_ks, available_e):
# new_ops = ops.copy()
# new_ops["ks"][i] = ks
# new_ops["e"][i] = e
# new_arch = {'string':self.get_string_from_ops(new_ops)}
# nbhd.append(new_arch)
for i in range(NUM_STAGE):
available_d = [d for d in DEPTH_LIST if d != ops["d"][i]]
for d in available_d:
new_ops = ops.copy()
new_ops["d"][i] = d
new_arch = {'string':self.get_string_from_ops(new_ops)}
nbhd.append(new_arch)
# if mutate_encoding == 'adj':
# for i in range(len(ops)):
# import pdb; pdb.set_trace()
# available = [op for op in OPS.keys() if op != ops[i]]
# for op in available:
# new_ops = ops.copy()
# new_ops[i] = op
# new_arch = {'string':self.get_string_from_ops(new_ops)}
# nbhd.append(new_arch)
elif mutate_encoding in ['path', 'trunc_path']:
if mutate_encoding == 'trunc_path':
path_blueprints = [[3], [0,4], [1,5]]
else:
path_blueprints = [[3], [0,4], [1,5], [0,2,5]]
ops = self.get_op_dict()
for blueprint in path_blueprints:
for new_path in itertools.product(OPS, repeat=len(blueprint)):
new_ops = ops.copy()
for i, op in enumerate(new_path):
new_ops[blueprint[i]] = op
# check if it's the same
same = True
for j in range(len(ops)):
if ops[j] != new_ops[j]:
same = False
if not same:
new_arch = {'string':self.get_string_from_ops(new_ops)}
nbhd.append(new_arch)
else:
print('{} is an invalid mutate encoding'.format(mutate_encoding))
raise NotImplementedError()
if shuffle:
random.shuffle(nbhd)
return nbhd
def get_unique_string(self):
ops = self.get_op_dict()
d_matrix = []
for i in range(NUM_STAGE):
ds = ops['d'][i]
for j in range(ds):
d_matrix.append(ds)
for j in range(MAX_LAYER_PER_STAGE - ds):
d_matrix.append('none')
string = ''
for i, (ks, e, d) in enumerate(zip(ops['ks'], ops['e'], d_matrix)):
if d == 'none':
string += f'0-0-0_'
else:
string += f'{d}-{ks}-{e}_'
return string[:-1]