first commit
9 .gitignore vendored Normal file
@@ -0,0 +1,9 @@
__pycache__
checkpoints/
*.pt
data/
exp/
vis/
results/
.empty/
.prev/
9 MobileNetV3/all_path.py Normal file
@@ -0,0 +1,9 @@
RAW_DATA_PATH = "./data/ofa/raw_data"
PROCESSED_DATA_PATH = "./data/ofa/data_transfer_nag"
SCORE_MODEL_DATA_PATH = "./data/ofa/data_score_model/ofa_database_500000.pt"
SCORE_MODEL_DATA_IDX_PATH = "./data/ofa/data_score_model/ridx-500000.pt"

NOISE_META_PREDICTOR_CKPT_PATH = "./checkpoints/ofa/noise_aware_meta_surrogate/model_best.pth.tar"
SCORE_MODEL_CKPT_PATH = "./checkpoints/ofa/score_model/model_best.pth.tar"
UNNOISE_META_PREDICTOR_CKPT_PATH = "./checkpoints/ofa/unnoised_meta_surrogate_from_metad2a"
CONFIG_PATH = "./configs/transfer_nag_ofa.pt"
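# These constants are imported throughout the repo (e.g.
# `from all_path import SCORE_MODEL_CKPT_PATH, SCORE_MODEL_DATA_PATH`); the
# relative paths presumably resolve against the MobileNetV3/ working directory.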
475 MobileNetV3/analysis/arch_functions.py Normal file
@@ -0,0 +1,475 @@
import numpy as np
import torch
import wandb
import igraph
from torch.nn.functional import one_hot


KS_LIST = [3, 5, 7]
EXPAND_LIST = [3, 4, 6]
DEPTH_LIST = [2, 3, 4]
NUM_STAGE = 5
MAX_LAYER_PER_STAGE = 4
MAX_N_BLOCK = NUM_STAGE * MAX_LAYER_PER_STAGE  # 20
OPS = {
    '3-3': 0, '3-4': 1, '3-6': 2,
    '5-3': 3, '5-4': 4, '5-6': 5,
    '7-3': 6, '7-4': 7, '7-6': 8,
}

OPS2STR = {
    0: '3-3', 1: '3-4', 2: '3-6',
    3: '5-3', 4: '5-4', 5: '5-6',
    6: '7-3', 7: '7-4', 8: '7-6',
}
NUM_OPS = len(OPS)
LONGEST_PATH_LENGTH = 20
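# Each op key encodes '<kernel size>-<expand ratio>': e.g. '5-4' is a 5x5
# kernel with expand ratio 4 (cf. KS_LIST / EXPAND_LIST). An architecture is
# a [MAX_N_BLOCK, NUM_OPS] = [20, 9] one-hot matrix: NUM_STAGE (5) stages
# with MAX_LAYER_PER_STAGE (4) layer slots each; an all-zero row marks a
# skipped layer slot.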


class BasicArchMetricsOFA(object):
    def __init__(self, train_ds=None, train_arch_str_list=None, except_inout=False, data_root=None):
        if data_root is not None:
            self.ofa = torch.load(data_root)
            self.train_arch_list = self.ofa['x']
        else:
            self.ofa = None
            self.train_arch_list = None
        self.ops_decoder = OPS
        self.except_inout = except_inout

    def get_string_from_onehot_x(self, x):
        x = torch.tensor(x)
        # number of active (non-zero) layer slots per stage == stage depth
        ds = torch.sum(x.view(NUM_STAGE, -1), dim=1)
        string = ''
        for i, op_vec in enumerate(x):
            if sum(op_vec) == 0:
                string += '0-0-0_'
            else:
                string += f'{int(ds[i // MAX_LAYER_PER_STAGE])}-' + OPS2STR[torch.nonzero(torch.tensor(op_vec)).item()] + '_'
        return string[:-1]
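    # Example (sketch): if the first stage has depth 2 with ops '3-3' and
    # '5-4' and all remaining slots are empty, the decoded string begins
    # '2-3-3_2-5-4_0-0-0_0-0-0_...', i.e. '<stage depth>-<kernel>-<expand>'
    # per slot, with '0-0-0' for empty slots.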

    def compute_validity(self, generated, adj=None, mask=None):
        """generated: list of pairs (positions, node_types)"""
        valid = []
        error_types = []
        valid_str = []
        for x in generated:
            is_valid, error_type = is_valid_OFA_x(x)
            if is_valid:
                valid.append(torch.tensor(x).long())
                valid_str.append(self.get_string_from_onehot_x(x))
            else:
                error_types.append(error_type)
        # guard against an empty batch (mirrors BasicArchMetricsMetaOFA below)
        validity = 0 if len(generated) == 0 else len(valid) / len(generated)
        return valid, validity, valid_str, None, error_types

    def compute_uniqueness(self, valid_arch):
        unique = []
        for x in valid_arch:
            if not any(torch.equal(x, tr_m) for tr_m in unique):
                unique.append(x)
        return unique, len(unique) / len(valid_arch)

    def compute_novelty(self, unique):
        num_novel = 0
        novel = []
        if self.train_arch_list is None:
            print("Dataset arch_str is None, novelty computation skipped")
            return 1, 1
        for arch in unique:
            if not any(torch.equal(arch, tr_m) for tr_m in self.train_arch_list):
                novel.append(arch)
                num_novel += 1
        return novel, num_novel / len(unique)

    def evaluate(self, generated, adj, mask, check_dataname='cifar10'):
        """generated: list of pairs"""
        valid_arch, validity, _, _, error_types = self.compute_validity(generated, adj, mask)

        print(f"Validity over {len(generated)} archs: {validity * 100 :.2f}%")
        error_1 = torch.sum(torch.tensor(error_types) == 1) / len(generated)
        error_2 = torch.sum(torch.tensor(error_types) == 2) / len(generated)
        error_3 = torch.sum(torch.tensor(error_types) == 3) / len(generated)
        print(f"Invalid-Multiple_Node_Type over {len(generated)} archs: {error_1 * 100 :.2f}%")
        print(f"Invalid-1OR2_Layers over {len(generated)} archs: {error_2 * 100 :.2f}%")
        print(f"Invalid-3AND4_Layers over {len(generated)} archs: {error_3 * 100 :.2f}%")

        if validity > 0:
            unique, uniqueness = self.compute_uniqueness(valid_arch)
            print(f"Uniqueness over {len(valid_arch)} valid archs: {uniqueness * 100 :.2f}%")

            if self.train_arch_list is not None:
                _, novelty = self.compute_novelty(unique)
                print(f"Novelty over {len(unique)} unique valid archs: {novelty * 100 :.2f}%")
            else:
                novelty = -1.0
        else:
            novelty = -1.0
            uniqueness = 0.0
            unique = []

        test_acc_list, flops_list, params_list, latency_list = [0], [0], [0], [0]
        all_arch_str = None
        return ([validity, uniqueness, novelty, error_1, error_2, error_3],
                unique,
                dict(test_acc_list=test_acc_list, flops_list=flops_list, params_list=params_list, latency_list=latency_list),
                all_arch_str)
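
# Usage sketch (assumes `generated` is a list of [20, 9] one-hot arrays from
# the sampler, and that data_root points at the OFA database .pt, e.g.
# all_path.SCORE_MODEL_DATA_PATH; adj/mask are unused by compute_validity):
#   metrics = BasicArchMetricsOFA(data_root=data_root)
#   vun, unique, props, _ = metrics.evaluate(generated, adj=None, mask=None)
#   validity, uniqueness, novelty = vun[0], vun[1], vun[2]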


class BasicArchMetricsMetaOFA(object):
    def __init__(self, train_ds=None, train_arch_str_list=None, except_inout=False, data_root=None):
        if data_root is not None:
            self.ofa = torch.load(data_root)
            self.train_arch_list = self.ofa['x']
        else:
            self.ofa = None
            self.train_arch_list = None
        self.ops_decoder = OPS

    def get_string_from_onehot_x(self, x):
        x = torch.tensor(x)
        ds = torch.sum(x.view(NUM_STAGE, -1), dim=1)
        string = ''
        for i, op_vec in enumerate(x):
            if sum(op_vec) == 0:
                string += '0-0-0_'
            else:
                string += f'{int(ds[i // MAX_LAYER_PER_STAGE])}-' + OPS2STR[torch.nonzero(torch.tensor(op_vec)).item()] + '_'
        return string[:-1]

    def compute_validity(self, generated, adj=None, mask=None):
        """generated: list of pairs (positions, node_types)"""
        valid = []
        valid_arch_str = []
        all_arch_str = []
        error_types = []
        for x in generated:
            is_valid, error_type = is_valid_OFA_x(x)
            if is_valid:
                valid.append(torch.tensor(x).long())
                arch_str = self.get_string_from_onehot_x(x)
                valid_arch_str.append(arch_str)
            else:
                arch_str = None
                error_types.append(error_type)
            all_arch_str.append(arch_str)
        validity = 0 if len(generated) == 0 else (len(valid) / len(generated))
        return valid, validity, valid_arch_str, all_arch_str, error_types

    def compute_uniqueness(self, valid_arch):
        unique = []
        for x in valid_arch:
            if not any(torch.equal(x, tr_m) for tr_m in unique):
                unique.append(x)
        return unique, len(unique) / len(valid_arch)

    def compute_novelty(self, unique):
        num_novel = 0
        novel = []
        if self.train_arch_list is None:
            print("Dataset arch_str is None, novelty computation skipped")
            return 1, 1
        for arch in unique:
            if not any(torch.equal(arch, tr_m) for tr_m in self.train_arch_list):
                novel.append(arch)
                num_novel += 1
        return novel, num_novel / len(unique)

    def evaluate(self, generated, adj, mask, check_dataname='imagenet1k'):
        """generated: list of pairs"""
        valid_arch, validity, _, _, error_types = self.compute_validity(generated, adj, mask)

        print(f"Validity over {len(generated)} archs: {validity * 100 :.2f}%")
        error_1 = torch.sum(torch.tensor(error_types) == 1) / len(generated)
        error_2 = torch.sum(torch.tensor(error_types) == 2) / len(generated)
        error_3 = torch.sum(torch.tensor(error_types) == 3) / len(generated)
        print(f"Invalid-Multiple_Node_Type over {len(generated)} archs: {error_1 * 100 :.2f}%")
        print(f"Invalid-1OR2_Layers over {len(generated)} archs: {error_2 * 100 :.2f}%")
        print(f"Invalid-3AND4_Layers over {len(generated)} archs: {error_3 * 100 :.2f}%")

        if validity > 0:
            unique, uniqueness = self.compute_uniqueness(valid_arch)
            print(f"Uniqueness over {len(valid_arch)} valid archs: {uniqueness * 100 :.2f}%")

            if self.train_arch_list is not None:
                _, novelty = self.compute_novelty(unique)
                print(f"Novelty over {len(unique)} unique valid archs: {novelty * 100 :.2f}%")
            else:
                novelty = -1.0
        else:
            novelty = -1.0
            uniqueness = 0.0
            unique = []

        test_acc_list, flops_list, params_list, latency_list = [0], [0], [0], [0]
        all_arch_str = None
        return ([validity, uniqueness, novelty, error_1, error_2, error_3],
                unique,
                dict(test_acc_list=test_acc_list, flops_list=flops_list, params_list=params_list, latency_list=latency_list),
                all_arch_str)


def get_arch_acc_info(nasbench201, arch, dataname='cifar10'):
    arch_index = nasbench201['str'].index(arch)
    test_acc = nasbench201['test-acc'][dataname][arch_index]
    flops = nasbench201['flops'][dataname][arch_index]
    params = nasbench201['params'][dataname][arch_index]
    latency = nasbench201['latency'][dataname][arch_index]
    return test_acc, flops, params, latency


def get_arch_acc_info_meta(nasbench201, arch, dataname='cifar10'):
    arch_index = nasbench201['str'].index(arch)
    flops = nasbench201['flops'][dataname][arch_index]
    params = nasbench201['params'][dataname][arch_index]
    latency = nasbench201['latency'][dataname][arch_index]
    if 'cifar' in dataname:
        test_acc = nasbench201['test-acc'][dataname][arch_index]
    else:
        # TODO: accuracy lookup for non-CIFAR datasets is not implemented yet
        test_acc = None
    return arch_index, test_acc, flops, params, latency


def is_valid_DAG(g, START_TYPE=0, END_TYPE=1):
    res = g.is_dag()
    n_start, n_end = 0, 0
    for v in g.vs:
        if v['type'] == START_TYPE:
            n_start += 1
        elif v['type'] == END_TYPE:
            n_end += 1
        if v.indegree() == 0 and v['type'] != START_TYPE:
            return False
        if v.outdegree() == 0 and v['type'] != END_TYPE:
            return False
    return res and n_start == 1 and n_end == 1
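# A graph is accepted iff it is acyclic, has exactly one START and one END
# node, and every non-START node has an incoming edge while every non-END
# node has an outgoing edge (single source, single sink).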


def check_single_node_type(x):
    for x_elem in x:
        if int(np.sum(x_elem)) != 1:
            return False
    return True


def check_start_end_nodes(x, START_TYPE, END_TYPE):
    if x[0][START_TYPE] != 1:
        return False
    if x[-1][END_TYPE] != 1:
        return False
    return True


def check_interm_node_types(x, START_TYPE, END_TYPE):
    for x_elem in x[1:-1]:
        if x_elem[START_TYPE] == 1:
            return False
        if x_elem[END_TYPE] == 1:
            return False
    return True


def construct_igraph(node_type, edge_type, ops_decoder, except_inout=True):
    assert node_type.shape[0] == edge_type.shape[0]

    START_TYPE = ops_decoder.index('input')
    END_TYPE = ops_decoder.index('output')

    g = igraph.Graph(directed=True)
    for i, node in enumerate(node_type):
        new_type = node.item()
        g.add_vertex(type=new_type)
        if new_type == END_TYPE:
            # connect every dangling (out-degree 0) vertex to the output node
            end_vertices = set(v.index for v in g.vs.select(_outdegree_eq=0) if v.index != g.vcount() - 1)
            for v in end_vertices:
                g.add_edge(v, i)
        elif i > 0:
            for ek in range(i):
                ek_score = edge_type[ek][i].item()
                if ek_score >= 0.5:
                    g.add_edge(ek, i)

    return g
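# Usage sketch: `node_type` is a length-N tensor of op indices, `edge_type`
# an [N, N] tensor of edge scores thresholded at 0.5. Note that `ops_decoder`
# must support .index(), i.e. be a list such as ['input', 'output', ...]
# rather than the OPS dict defined above:
#   g = construct_igraph(node_type, edge_type, ['input', 'output', 'conv3'])
#   assert is_valid_DAG(g, START_TYPE=0, END_TYPE=1)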


def compute_arch_metrics(arch_list, adj, mask, train_arch_str_list,
                         train_ds, timestep=None, name=None, except_inout=False, data_root=None):
    """arch_list: (dict)"""
    metrics = BasicArchMetricsOFA(data_root=data_root)
    arch_metrics = metrics.evaluate(arch_list, adj, mask, check_dataname='cifar10')
    all_arch_str = arch_metrics[-1]

    if wandb.run:
        arch_prop = arch_metrics[2]
        test_acc_list = arch_prop['test_acc_list']
        flops_list = arch_prop['flops_list']
        params_list = arch_prop['params_list']
        latency_list = arch_prop['latency_list']
        if arch_metrics[0][1] > 0.:  # uniqueness > 0.
            dic = {
                'Validity': arch_metrics[0][0], 'Uniqueness': arch_metrics[0][1], 'Novelty': arch_metrics[0][2],
                'test_acc_max': np.max(test_acc_list), 'test_acc_min': np.min(test_acc_list), 'test_acc_mean': np.mean(test_acc_list), 'test_acc_std': np.std(test_acc_list),
                'flops_max': np.max(flops_list), 'flops_min': np.min(flops_list), 'flops_mean': np.mean(flops_list), 'flops_std': np.std(flops_list),
                'params_max': np.max(params_list), 'params_min': np.min(params_list), 'params_mean': np.mean(params_list), 'params_std': np.std(params_list),
                'latency_max': np.max(latency_list), 'latency_min': np.min(latency_list), 'latency_mean': np.mean(latency_list), 'latency_std': np.std(latency_list),
            }
        else:
            dic = {
                'Validity': arch_metrics[0][0], 'Uniqueness': arch_metrics[0][1], 'Novelty': arch_metrics[0][2],
                'test_acc_max': -1, 'test_acc_min': -1, 'test_acc_mean': -1, 'test_acc_std': 0,
                'flops_max': -1, 'flops_min': -1, 'flops_mean': -1, 'flops_std': 0,
                'params_max': -1, 'params_min': -1, 'params_mean': -1, 'params_std': 0,
                'latency_max': -1, 'latency_min': -1, 'latency_mean': -1, 'latency_std': 0,
            }
        if timestep is not None:
            dic.update({'step': timestep})

        wandb.log(dic)

    return arch_metrics, all_arch_str


def compute_arch_metrics_meta(
        arch_list, adj, mask, train_arch_str_list, train_ds,
        timestep=None, check_dataname='cifar10', name=None):
    """arch_list: (dict)"""

    metrics = BasicArchMetricsMetaOFA(train_ds, train_arch_str_list)
    arch_metrics = metrics.evaluate(arch_list, adj, mask, check_dataname=check_dataname)
    if wandb.run:
        arch_prop = arch_metrics[2]
        if name != 'ofa':
            arch_idx_list = arch_prop['arch_idx_list']
        test_acc_list = arch_prop['test_acc_list']
        flops_list = arch_prop['flops_list']
        params_list = arch_prop['params_list']
        latency_list = arch_prop['latency_list']
        if arch_metrics[0][1] > 0.:  # uniqueness > 0.
            dic = {
                'Validity': arch_metrics[0][0], 'Uniqueness': arch_metrics[0][1], 'Novelty': arch_metrics[0][2],
                'test_acc_max': np.max(test_acc_list), 'test_acc_min': np.min(test_acc_list), 'test_acc_mean': np.mean(test_acc_list), 'test_acc_std': np.std(test_acc_list),
                'flops_max': np.max(flops_list), 'flops_min': np.min(flops_list), 'flops_mean': np.mean(flops_list), 'flops_std': np.std(flops_list),
                'params_max': np.max(params_list), 'params_min': np.min(params_list), 'params_mean': np.mean(params_list), 'params_std': np.std(params_list),
                'latency_max': np.max(latency_list), 'latency_min': np.min(latency_list), 'latency_mean': np.mean(latency_list), 'latency_std': np.std(latency_list),
            }
        else:
            dic = {
                'Validity': arch_metrics[0][0], 'Uniqueness': arch_metrics[0][1], 'Novelty': arch_metrics[0][2],
                'test_acc_max': -1, 'test_acc_min': -1, 'test_acc_mean': -1, 'test_acc_std': 0,
                'flops_max': -1, 'flops_min': -1, 'flops_mean': -1, 'flops_std': 0,
                'params_max': -1, 'params_min': -1, 'params_mean': -1, 'params_std': 0,
                'latency_max': -1, 'latency_min': -1, 'latency_mean': -1, 'latency_std': 0,
            }
        if timestep is not None:
            dic.update({'step': timestep})
        wandb.log(dic)

    return arch_metrics


def check_multiple_nodes(x):
    assert len(x.shape) == 2
    for x_elem in x:
        x_elem = np.array(x_elem)
        if int(np.sum(x_elem)) > 1:
            return False
    return True


def check_inout_node(x, START_TYPE=0, END_TYPE=1):
    assert len(x.shape) == 2
    return x[0][START_TYPE] == 1 and x[-1][END_TYPE] == 1


def check_none_in_1_and_2_layers(x, NONE_TYPE=None):
    """The first two layer slots of every stage must hold an op (minimum depth is 2)."""
    assert len(x.shape) == 2
    first_and_second_layers = [0, 1, 4, 5, 8, 9, 12, 13, 16, 17]
    for layer in first_and_second_layers:
        if int(np.sum(x[layer])) == 0:
            return False
    return True


def check_none_in_3_and_4_layers(x, NONE_TYPE=None):
    """A stage's 4th layer slot may be active only if its 3rd slot is active."""
    assert len(x.shape) == 2
    third_layers = [2, 6, 10, 14, 18]

    for layer in third_layers:
        if int(np.sum(x[layer])) == 0:
            if int(np.sum(x[layer + 1])) != 0:
                return False
    return True
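# Example (sketch): with 4 slots per stage, (op, op, 0, 0) has depth 2 and
# (op, op, op, op) depth 4, both valid; (op, op, 0, op) is rejected by
# check_none_in_3_and_4_layers because slot 4 is active while slot 3 is empty.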


def check_interm_inout_node(x, START_TYPE, END_TYPE):
    for x_elem in x[1:-1]:
        if x_elem[START_TYPE] == 1:
            return False
        if x_elem[END_TYPE] == 1:
            return False
    return True


def is_valid_OFA_x(x):
    ERROR = {
        'MULTIPLE_NODES': 1,
        'INVALID_1OR2_LAYERS': 2,
        'INVALID_3AND4_LAYERS': 3,
        'NO_ERROR': -1,
    }
    if not check_multiple_nodes(x):
        return False, ERROR['MULTIPLE_NODES']

    if not check_none_in_1_and_2_layers(x):
        return False, ERROR['INVALID_1OR2_LAYERS']

    if not check_none_in_3_and_4_layers(x):
        return False, ERROR['INVALID_3AND4_LAYERS']

    return True, ERROR['NO_ERROR']
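# Usage sketch (x is a [20, 9] one-hot numpy array or nested list):
#   ok, err = is_valid_OFA_x(x)
#   if not ok:
#       print(f'rejected with error code {err}')  # 1, 2 or 3, see ERROR above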


def get_x_adj_from_opsdict_ofa(ops):
    node_types = torch.zeros(NUM_STAGE * MAX_LAYER_PER_STAGE).long()  # w/o in / out
    num_vertices = len(OPS)
    num_nodes = NUM_STAGE * MAX_LAYER_PER_STAGE
    d_matrix = []

    # expand the per-stage depths into one entry per layer slot
    for i in range(NUM_STAGE):
        ds = ops['d'][i]
        for _ in range(ds):
            d_matrix.append(ds)
        for _ in range(MAX_LAYER_PER_STAGE - ds):
            d_matrix.append('none')

    for i, (ks, e, d) in enumerate(zip(ops['ks'], ops['e'], d_matrix)):
        if d != 'none':
            node_types[i] = OPS[f'{ks}-{e}']

    x = one_hot(node_types, num_vertices).float()

    def get_adj():
        # simple chain adjacency: layer i feeds layer i + 1
        adj = torch.zeros(num_nodes, num_nodes)
        for i in range(num_nodes - 1):
            adj[i, i + 1] = 1
        adj = np.array(adj)
        return adj

    adj = get_adj()
    return x, adj
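# Example ops dict (sketch, inferred from the loops above): 20 kernel sizes
# and 20 expand ratios (one per layer slot) plus 5 per-stage depths:
#   ops = {'ks': [3] * 20, 'e': [4] * 20, 'd': [2, 2, 3, 4, 2]}
#   x, adj = get_x_adj_from_opsdict_ofa(ops)  # x: [20, 9] one-hot, adj: chain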


def get_string_from_onehot_x(x):
    x = torch.tensor(x)
    ds = torch.sum(x.view(NUM_STAGE, -1), dim=1)
    string = ''
    for i, op_vec in enumerate(x):
        if sum(op_vec) == 0:
            string += '0-0-0_'
        else:
            string += f'{int(ds[i // MAX_LAYER_PER_STAGE])}-' + OPS2STR[torch.nonzero(torch.tensor(op_vec)).item()] + '_'
    return string[:-1]
114 MobileNetV3/analysis/arch_metrics.py Normal file
@@ -0,0 +1,114 @@
from analysis.arch_functions import compute_arch_metrics, compute_arch_metrics_meta
import wandb
import torch.nn as nn


class SamplingArchMetrics(nn.Module):
    def __init__(self, config, train_ds, exp_name):
        super().__init__()

        self.exp_name = exp_name
        self.train_ds = train_ds
        if config.data.name == 'ofa':
            self.train_arch_str_list = train_ds.x_list_
        else:
            self.train_arch_str_list = train_ds.arch_str_list_
        self.name = config.data.name
        self.except_inout = config.data.except_inout
        self.data_root = config.data.root

    def forward(self, arch_list: list, adj, mask, this_sample_dir, test=False, timestep=None):
        """Compute validity / uniqueness / novelty for a batch of sampled archs.

        :param arch_list: list of archs
        :param adj: [batch_size, num_nodes, num_nodes]
        :param mask: [batch_size, num_nodes, num_nodes]
        """
        arch_metrics, all_arch_str = compute_arch_metrics(
            arch_list, adj, mask, self.train_arch_str_list, self.train_ds, timestep=timestep,
            name=self.name, except_inout=self.except_inout, data_root=self.data_root)
        # arch_metrics = ([validity, uniqueness, novelty, error_1, error_2, error_3],
        #                 unique,
        #                 dict(test_acc_list=..., flops_list=..., params_list=..., latency_list=...),
        #                 all_arch_str)

        if test and self.name != 'ofa':
            with open('final_.txt', 'w') as fp:
                # write each arch on a new line
                for arch_str in all_arch_str:
                    fp.write("%s\n" % arch_str)
            print('All archs saved')

        if self.name != 'ofa':
            valid_unique_arch = arch_metrics[1]
            valid_unique_arch_prop_dict = arch_metrics[2]  # test_acc, flops, params, latency
            with open(f'{this_sample_dir}/valid_unique_archs.txt', 'w') as textfile:
                for i in range(len(valid_unique_arch)):
                    textfile.write(f"Arch: {valid_unique_arch[i]} \n")
                    textfile.write(f"Test Acc: {valid_unique_arch_prop_dict['test_acc_list'][i]} \n")
                    textfile.write(f"FLOPs: {valid_unique_arch_prop_dict['flops_list'][i]} \n")
                    textfile.write(f"#Params: {valid_unique_arch_prop_dict['params_list'][i]} \n")
                    textfile.write(f"Latency: {valid_unique_arch_prop_dict['latency_list'][i]} \n \n")
                textfile.writelines(valid_unique_arch)

        return arch_metrics


class SamplingArchMetricsMeta(nn.Module):
    def __init__(self, config, train_ds, exp_name, train_index=None, nasbench=None):
        super().__init__()

        self.exp_name = exp_name
        self.train_ds = train_ds
        self.search_space = config.data.name
        if self.search_space == 'ofa':
            self.train_arch_str_list = None
        else:
            self.train_arch_str_list = [train_ds.arch_str_list[i] for i in train_ds.idx_lst['train']]

    def forward(self, arch_list: list, adj, mask, this_sample_dir, test=False,
                timestep=None, check_dataname='cifar10'):
        """Compute sampling metrics for a batch via the meta-surrogate variant.

        :param arch_list: list of archs
        :param adj: [batch_size, num_nodes, num_nodes]
        :param mask: [batch_size, num_nodes, num_nodes]
        """
        arch_metrics = compute_arch_metrics_meta(arch_list, adj, mask, self.train_arch_str_list,
                                                 self.train_ds, timestep=timestep, check_dataname=check_dataname,
                                                 name=self.search_space)
        all_arch_str = arch_metrics[-1]

        if test:
            with open('final_.txt', 'w') as fp:
                # write each arch on a new line
                for arch_str in all_arch_str:
                    fp.write("%s\n" % arch_str)
            print('All archs saved')

        valid_unique_arch = arch_metrics[1]  # arch_str
        valid_unique_arch_prop_dict = arch_metrics[2]  # test_acc, flops, params, latency
        if self.search_space != 'ofa':
            with open(f'{this_sample_dir}/valid_unique_archs.txt', 'w') as textfile:
                for i in range(len(valid_unique_arch)):
                    textfile.write(f"Arch: {valid_unique_arch[i]} \n")
                    textfile.write(f"Arch Index: {valid_unique_arch_prop_dict['arch_idx_list'][i]} \n")
                    textfile.write(f"Test Acc: {valid_unique_arch_prop_dict['test_acc_list'][i]} \n")
                    textfile.write(f"FLOPs: {valid_unique_arch_prop_dict['flops_list'][i]} \n")
                    textfile.write(f"#Params: {valid_unique_arch_prop_dict['params_list'][i]} \n")
                    textfile.write(f"Latency: {valid_unique_arch_prop_dict['latency_list'][i]} \n \n")
                textfile.writelines(valid_unique_arch)

        return arch_metrics
547 MobileNetV3/analysis/visualization.py Normal file
@@ -0,0 +1,547 @@
import os
import torch
import imageio
import networkx as nx
import numpy as np
import wandb
import matplotlib.pyplot as plt

import datasets_nas
from configs.ckpt import DATAROOT_NB201


class ArchVisualization:
    def __init__(self, config, remove_none=False, exp_name=None):
        self.config = config
        self.remove_none = remove_none
        self.exp_name = exp_name
        self.num_graphs_to_visualize = config.log.num_graphs_to_visualize
        self.nasbench201 = torch.load(DATAROOT_NB201)

        # NAS-Bench-201 node (operation) labels
        self.labels = {
            0: 'input',
            1: 'output',
            2: 'conv3',
            3: 'sep3',
            4: 'conv5',
            5: 'sep5',
            6: 'avg3',
            7: 'max3',
        }

        self.colors = ['skyblue', 'pink', 'yellow', 'orange', 'greenyellow', 'green', 'azure', 'beige']

    def to_networkx_directed(self, node_list, adjacency_matrix):
        """
        Convert one sampled architecture to a directed networkx graph.
        node_list: node types of one sample (length n)
        adjacency_matrix: adjacency matrix of the graph (n x n)
        """
        graph = nx.DiGraph()
        # add nodes to the graph
        for i in range(len(node_list)):
            if node_list[i] == -1:
                continue
            graph.add_node(i, number=i, symbol=node_list[i], color_val=node_list[i])

        # keep only the upper triangle so each edge is added once, in forward order
        rows, cols = np.where(torch.triu(torch.tensor(adjacency_matrix), diagonal=1).numpy() >= 1)
        edges = zip(rows.tolist(), cols.tolist())
        for edge in edges:
            edge_type = adjacency_matrix[edge[0]][edge[1]]
            graph.add_edge(edge[0], edge[1], color=float(edge_type), weight=3 * edge_type)

        return graph

    def visualize_non_molecule(self, graph, pos, path, iterations=100, node_size=1200, largest_component=False):
        if largest_component:
            CGs = [graph.subgraph(c) for c in nx.connected_components(graph)]
            CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
            graph = CGs[0]

        # plot the graph structure with colors
        if pos is None:
            pos = nx.nx_pydot.graphviz_layout(graph, prog="dot")

        # set node colors based on the operations
        plt.figure()
        nx.draw(graph, pos=pos, labels=self.labels, arrows=True, node_shape="s",
                node_size=node_size, node_color=self.colors, edge_color='grey', with_labels=True)

        plt.savefig(path)
        plt.close("all")

    def visualize(self, path: str, graphs: list, log='graph', adj=None):
        # define path to save figures
        os.makedirs(path, exist_ok=True)

        # visualize the final architectures
        for i in range(self.num_graphs_to_visualize):
            file_path = os.path.join(path, 'graph_{}.png'.format(i))
            graph = self.to_networkx_directed(graphs[i], adj[0].detach().cpu().numpy())
            self.visualize_non_molecule(graph, pos=None, path=file_path)
            im = plt.imread(file_path)
            if wandb.run and log is not None:
                wandb.log({log: [wandb.Image(im, caption=file_path)]})

    def visualize_chain(self, path, sample_list, adjacency_matrix,
                        r_valid_chain, r_uniqueness_chain, r_novel_chain):
        # convert graphs to networkx
        graphs = [self.to_networkx_directed(sample_list[i], adjacency_matrix[i]) for i in range(sample_list.shape[0])]
        # lay out nodes once, from the final graph, so all frames are aligned
        final_graph = graphs[-1]
        final_pos = nx.nx_pydot.graphviz_layout(final_graph, prog="dot")

        # draw gif
        save_paths = []
        num_frames = sample_list.shape[0]

        for frame in range(num_frames):
            file_name = os.path.join(path, 'frame_{}.png'.format(frame))
            self.visualize_non_molecule(graphs[frame], pos=final_pos, path=file_name)
            save_paths.append(file_name)

        imgs = [imageio.imread(fn) for fn in save_paths]
        gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1]))
        print(f'==> Save gif at {gif_path}')
        imgs.extend([imgs[-1]] * 10)  # hold the last frame
        imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5)
        if wandb.run:
            wandb.log({'chain': [wandb.Video(gif_path, caption=gif_path, format="gif")]})

    def visualize_chain_vun(self, path, r_valid_chain, r_unique_chain, r_novel_chain, sde, sampling_eps, number_chain_steps=None):
        os.makedirs(path, exist_ok=True)
        timesteps = torch.linspace(sde.T, sampling_eps, sde.N)

        if number_chain_steps is not None:
            # subsample the sde.N timesteps down to number_chain_steps
            timesteps_ = []
            n = int(sde.N / number_chain_steps)
            for i, t in enumerate(timesteps):
                if i % n == n - 1:
                    timesteps_.append(t.item())
            assert len(timesteps_) == number_chain_steps
            timesteps_ = timesteps_[::-1]
        else:
            timesteps_ = list(timesteps.numpy())[::-1]

        # validity
        plt.clf()
        fig, ax = plt.subplots()
        ax.plot(timesteps_, r_valid_chain, color='red')
        ax.set_title('Validity')
        ax.set_xlabel('time')
        ax.set_ylabel('Validity')
        plt.show()
        file_path = os.path.join(path, 'validity.png')
        plt.savefig(file_path)
        plt.close("all")
        print(f'==> Save plot at {file_path}')
        im = plt.imread(file_path)
        if wandb.run:
            wandb.log({'r_valid_chains': [wandb.Image(im, caption=file_path)]})

        # uniqueness
        plt.clf()
        fig, ax = plt.subplots()
        ax.plot(timesteps_, r_unique_chain, color='green')
        ax.set_title('Uniqueness')
        ax.set_xlabel('time')
        ax.set_ylabel('Uniqueness')
        plt.show()
        file_path = os.path.join(path, 'uniqueness.png')
        plt.savefig(file_path)
        plt.close("all")
        print(f'==> Save plot at {file_path}')
        im = plt.imread(file_path)
        if wandb.run:
            wandb.log({'r_uniqueness_chains': [wandb.Image(im, caption=file_path)]})

        # novelty
        plt.clf()
        fig, ax = plt.subplots()
        ax.plot(timesteps_, r_novel_chain, color='blue')
        ax.set_title('Novelty')
        ax.set_xlabel('time')
        ax.set_ylabel('Novelty')
        file_path = os.path.join(path, 'novelty.png')
        plt.savefig(file_path)
        plt.close("all")
        print(f'==> Save plot at {file_path}')
        im = plt.imread(file_path)
        if wandb.run:
            wandb.log({'r_novelty_chains': [wandb.Image(im, caption=file_path)]})

    def visualize_grad_norm(self, path, score_grad_norm_p, classifier_grad_norm_p,
                            score_grad_norm_c, classifier_grad_norm_c, sde, sampling_eps,
                            number_chain_steps=None):
        os.makedirs(path, exist_ok=True)
        timesteps = torch.linspace(sde.T, sampling_eps, sde.N)
        timesteps_ = list(timesteps.numpy())[::-1]

        # pad with -1 when the corrector was disabled
        if len(score_grad_norm_c) == 0:
            score_grad_norm_c = [-1] * len(score_grad_norm_p)
        if len(classifier_grad_norm_c) == 0:
            classifier_grad_norm_c = [-1] * len(classifier_grad_norm_p)

        plt.clf()
        fig, ax1 = plt.subplots()

        color_1 = 'red'
        ax1.set_title('grad_norm (predictor)')
        ax1.set_xlabel('time')
        ax1.set_ylabel('score_grad_norm (predictor)', color=color_1)
        ax1.plot(timesteps_, score_grad_norm_p, color=color_1)
        ax1.tick_params(axis='y', labelcolor=color_1)

        ax2 = ax1.twinx()
        color_2 = 'blue'
        ax2.set_ylabel('classifier_grad_norm (predictor)', color=color_2)
        ax2.plot(timesteps_, classifier_grad_norm_p, color=color_2)
        ax2.tick_params(axis='y', labelcolor=color_2)
        fig.tight_layout()
        plt.show()

        file_path = os.path.join(path, 'grad_norm_p.png')
        plt.savefig(file_path)
        plt.close("all")
        print(f'==> Save plot at {file_path}')
        im = plt.imread(file_path)
        if wandb.run:
            wandb.log({'grad_norm_p': [wandb.Image(im, caption=file_path)]})

        plt.clf()
        fig, ax1 = plt.subplots()

        color_1 = 'green'
        ax1.set_title('grad_norm (corrector)')
        ax1.set_xlabel('time')
        ax1.set_ylabel('score_grad_norm (corrector)', color=color_1)
        ax1.plot(timesteps_, score_grad_norm_c, color=color_1)
        ax1.tick_params(axis='y', labelcolor=color_1)

        ax2 = ax1.twinx()
        color_2 = 'yellow'
        ax2.set_ylabel('classifier_grad_norm (corrector)', color=color_2)
        ax2.plot(timesteps_, classifier_grad_norm_c, color=color_2)
        ax2.tick_params(axis='y', labelcolor=color_2)
        fig.tight_layout()
        plt.show()

        file_path = os.path.join(path, 'grad_norm_c.png')
        plt.savefig(file_path)
        plt.close("all")
        print(f'==> Save plot at {file_path}')
        im = plt.imread(file_path)
        if wandb.run:
            wandb.log({'grad_norm_c': [wandb.Image(im, caption=file_path)]})

    def visualize_scatter(self, path,
                          score_config, classifier_config,
                          sampled_arch_metric, plot_textstr=True,
                          x_axis='latency', y_axis='test-acc', x_label='Latency (ms)', y_label='Accuracy (%)',
                          log='scatter', check_dataname='cifar10-valid',
                          selected_arch_idx_list_topN=None, selected_arch_idx_list=None,
                          train_idx_list=None, return_file_path=False):
        os.makedirs(path, exist_ok=True)

        tg_dataset = classifier_config.data.tg_dataset

        train_ds_s, eval_ds_s, test_ds_s = datasets_nas.get_dataset(score_config)
        if selected_arch_idx_list is None:
            train_ds_c, eval_ds_c, test_ds_c = datasets_nas.get_dataset(classifier_config)
        else:
            train_ds_c, eval_ds_c, test_ds_c = datasets_nas.get_dataset_iter(classifier_config)

        plt.clf()
        fig, ax = plt.subplots()

        # entire architectures
        entire_ds_x = train_ds_s.get_unnoramlized_entire_data(x_axis, tg_dataset)
        entire_ds_y = train_ds_s.get_unnoramlized_entire_data(y_axis, tg_dataset)
        ax.scatter(entire_ds_x, entire_ds_y, color='lightgray', alpha=0.5, label='Entire', marker=',')

        # architectures trained by the classifier
        train_ds_c_x = train_ds_c.get_unnoramlized_data(x_axis, tg_dataset)
        train_ds_c_y = train_ds_c.get_unnoramlized_data(y_axis, tg_dataset)
        ax.scatter(train_ds_c_x, train_ds_c_y, color='black', alpha=0.8, label='Trained by Predictor Model')

        # oracle: the best architecture in the entire space on the y-axis metric
        oracle_idx = torch.argmax(torch.tensor(entire_ds_y)).item()
        oracle_item_x = entire_ds_x[oracle_idx]
        oracle_item_y = entire_ds_y[oracle_idx]
        ax.scatter(oracle_item_x, oracle_item_y, color='red', alpha=1.0, label='Oracle', marker='*', s=150)

        # architectures sampled by the score_model & classifier
        AXIS_TO_PROP = {
            'val-acc': 'val_acc_list',
            'test-acc': 'test_acc_list',
            'latency': 'latency_list',
            'flops': 'flops_list',
            'params': 'params_list',
        }
        sampled_ds_c_x = sampled_arch_metric[2][AXIS_TO_PROP[x_axis]]
        sampled_ds_c_y = sampled_arch_metric[2][AXIS_TO_PROP[y_axis]]
        ax.scatter(sampled_ds_c_x, sampled_ds_c_y, color='limegreen', alpha=0.8, label='Sampled', marker='x')

        ax.set_title(f'{tg_dataset.upper()} Dataset')
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

        # initialize so get_textstr can be called even without a topN selection
        selected_topN_ds_x, selected_topN_ds_y = None, None
        if selected_arch_idx_list_topN is not None:
            selected_arch_topN_info_dict = get_arch_acc_info_dict(
                self.nasbench201, dataname=check_dataname, arch_index_list=selected_arch_idx_list_topN)
            selected_topN_ds_x = selected_arch_topN_info_dict[AXIS_TO_PROP[x_axis]]
            selected_topN_ds_y = selected_arch_topN_info_dict[AXIS_TO_PROP[y_axis]]
            ax.scatter(selected_topN_ds_x, selected_topN_ds_y, color='pink', alpha=0.8, label='Selected_topN', marker='x')

        # architectures selected by the predictor
        selected_ds_x, selected_ds_y = None, None
        if selected_arch_idx_list is not None:
            selected_arch_info_dict = get_arch_acc_info_dict(
                self.nasbench201, dataname=check_dataname, arch_index_list=selected_arch_idx_list)
            selected_ds_x = selected_arch_info_dict[AXIS_TO_PROP[x_axis]]
            selected_ds_y = selected_arch_info_dict[AXIS_TO_PROP[y_axis]]
            ax.scatter(selected_ds_x, selected_ds_y, color='blue', alpha=0.8, label='Selected', marker='x')

        if plot_textstr:
            textstr = self.get_textstr(sampled_arch_metric=sampled_arch_metric,
                                       sampled_ds_c_x=sampled_ds_c_x, sampled_ds_c_y=sampled_ds_c_y,
                                       x_axis=x_axis, y_axis=y_axis,
                                       classifier_config=classifier_config,
                                       selected_ds_x=selected_ds_x, selected_ds_y=selected_ds_y,
                                       selected_topN_ds_x=selected_topN_ds_x, selected_topN_ds_y=selected_topN_ds_y,
                                       oracle_idx=oracle_idx, train_idx_list=train_idx_list)

            props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
            ax.text(0.6, 0.4, textstr, transform=ax.transAxes, verticalalignment='bottom', bbox=props, fontsize='x-small')
        ax.legend(loc="lower right")

        plt.subplots_adjust(left=0, bottom=0, right=1, top=1)
        plt.show()
        plt.tight_layout()

        file_path = os.path.join(path, 'scatter.png')
        plt.savefig(file_path)
        plt.close("all")
        print(f'==> Save scatter plot at {file_path}')

        if return_file_path:
            return file_path

        im = plt.imread(file_path)
        if wandb.run and log is not None:
            wandb.log({log: [wandb.Image(im, caption=file_path)]})

    def visualize_scatter_chain(self, path, score_config, classifier_config, sampled_arch_metric_chain, plot_textstr=True,
                                x_axis='latency', y_axis='test-acc', x_label='Latency (ms)', y_label='Accuracy (%)',
                                log='scatter_chain'):
        # draw one scatter frame per chain step, then assemble a gif
        os.makedirs(path, exist_ok=True)
        save_paths = []
        num_frames = len(sampled_arch_metric_chain)

        tg_dataset = classifier_config.data.tg_dataset

        train_ds_s, eval_ds_s, test_ds_s = datasets_nas.get_dataset(score_config)
        train_ds_c, eval_ds_c, test_ds_c = datasets_nas.get_dataset(classifier_config)

        # entire architectures
        entire_ds_x = train_ds_s.get_unnoramlized_entire_data(x_axis, tg_dataset)
        entire_ds_y = train_ds_s.get_unnoramlized_entire_data(y_axis, tg_dataset)

        # architectures trained by the score_model
        train_ds_s_x = train_ds_s.get_unnoramlized_data(x_axis, tg_dataset)
        train_ds_s_y = train_ds_s.get_unnoramlized_data(y_axis, tg_dataset)

        # architectures trained by the classifier
        train_ds_c_x = train_ds_c.get_unnoramlized_data(x_axis, tg_dataset)
        train_ds_c_y = train_ds_c.get_unnoramlized_data(y_axis, tg_dataset)

        # oracle (highest validation accuracy)
        oracle_idx = torch.argmax(torch.tensor(train_ds_s.get_unnoramlized_entire_data('val-acc', tg_dataset))).item()
        oracle_item_x = entire_ds_x[oracle_idx]
        oracle_item_y = entire_ds_y[oracle_idx]

        for frame in range(num_frames):
            sampled_arch_metric = sampled_arch_metric_chain[frame]

            plt.clf()
            fig, ax = plt.subplots()

            # entire architectures
            ax.scatter(entire_ds_x, entire_ds_y, color='lightgray', alpha=0.5, label='Entire', marker=',')
            # architectures trained by the score_model
            ax.scatter(train_ds_s_x, train_ds_s_y, color='gray', alpha=0.8, label='Trained by Score Model')
            # architectures trained by the classifier
            ax.scatter(train_ds_c_x, train_ds_c_y, color='black', alpha=0.8, label='Trained by Predictor Model')
            # oracle
            ax.scatter(oracle_item_x, oracle_item_y, color='red', alpha=1.0, label='Oracle', marker='*', s=150)
            # architectures sampled by the score_model & classifier
            AXIS_TO_PROP = {
                'test-acc': 'test_acc_list',
                'latency': 'latency_list',
                'flops': 'flops_list',
                'params': 'params_list',
            }
            sampled_ds_c_x = sampled_arch_metric[2][AXIS_TO_PROP[x_axis]]
            sampled_ds_c_y = sampled_arch_metric[2][AXIS_TO_PROP[y_axis]]
            ax.scatter(sampled_ds_c_x, sampled_ds_c_y, color='limegreen', alpha=0.8, label='Sampled', marker='x')

            ax.set_title(f'{tg_dataset.upper()} Dataset')
            ax.set_xlabel(x_label)
            ax.set_ylabel(y_label)

            if plot_textstr:
                textstr = self.get_textstr(sampled_arch_metric, sampled_ds_c_x, sampled_ds_c_y,
                                           x_axis, y_axis, classifier_config)
                props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
                ax.text(0.6, 0.3, textstr, transform=ax.transAxes, verticalalignment='bottom', bbox=props)
            ax.legend(loc="lower right")

            plt.subplots_adjust(left=0, bottom=0, right=1, top=1)
            plt.show()

            file_path = os.path.join(path, f'frame_{frame}.png')
            plt.savefig(file_path)
            plt.close("all")
            print(f'==> Save scatter plot at {file_path}')
            save_paths.append(file_path)

            im = plt.imread(file_path)
            if wandb.run and log is not None:
                wandb.log({log: [wandb.Image(im, caption=file_path)]})

        # draw gif (frames reversed so time runs forward)
        imgs = [imageio.imread(fn) for fn in save_paths[::-1]]
        gif_path = os.path.join(path, 'scatter.gif')
        print(f'==> Save gif at {gif_path}')
        imgs.extend([imgs[-1]] * 10)
        imageio.mimsave(gif_path, imgs, subrectangles=True, fps=5)
        if wandb.run:
            wandb.log({'chain_gif': [wandb.Video(gif_path, caption=gif_path, format="gif")]})

    def get_textstr(self,
                    sampled_arch_metric,
                    sampled_ds_c_x, sampled_ds_c_y,
                    x_axis='latency', y_axis='test-acc',
                    classifier_config=None,
                    selected_ds_x=None, selected_ds_y=None,
                    selected_topN_ds_x=None, selected_topN_ds_y=None,
                    oracle_idx=None, train_idx_list=None):
        mean_v_x = round(np.mean(np.array(sampled_ds_c_x)), 4)
        std_v_x = round(np.std(np.array(sampled_ds_c_x)), 4)
        max_v_x = round(np.max(np.array(sampled_ds_c_x)), 4)
        min_v_x = round(np.min(np.array(sampled_ds_c_x)), 4)

        mean_v_y = round(np.mean(np.array(sampled_ds_c_y)), 4)
        std_v_y = round(np.std(np.array(sampled_ds_c_y)), 4)
        max_v_y = round(np.max(np.array(sampled_ds_c_y)), 4)
        min_v_y = round(np.min(np.array(sampled_ds_c_y)), 4)

        if selected_ds_x is not None:
            mean_v_x_s = round(np.mean(np.array(selected_ds_x)), 4)
            std_v_x_s = round(np.std(np.array(selected_ds_x)), 4)
            max_v_x_s = round(np.max(np.array(selected_ds_x)), 4)
            min_v_x_s = round(np.min(np.array(selected_ds_x)), 4)

        if selected_ds_y is not None:
            mean_v_y_s = round(np.mean(np.array(selected_ds_y)), 4)
            std_v_y_s = round(np.std(np.array(selected_ds_y)), 4)
            max_v_y_s = round(np.max(np.array(selected_ds_y)), 4)
            min_v_y_s = round(np.min(np.array(selected_ds_y)), 4)

        textstr = ''
        r_valid, r_unique, r_novel = round(sampled_arch_metric[0][0], 4), round(sampled_arch_metric[0][1], 4), round(sampled_arch_metric[0][2], 4)
        textstr += f'V-{r_valid} | U-{r_unique} | N-{r_novel} \n'
        textstr += f'Predictor (Noise-aware-{str(classifier_config.training.noised)[0]}, k={self.config.sampling.classifier_scale}) \n'
        textstr += f'=> Sampled {x_axis} \n'
        textstr += f'Mean-{mean_v_x} | Std-{std_v_x} \n'
        textstr += f'Max-{max_v_x} | Min-{min_v_x} \n'
        textstr += f'=> Sampled {y_axis} \n'
        textstr += f'Mean-{mean_v_y} | Std-{std_v_y} \n'
        textstr += f'Max-{max_v_y} | Min-{min_v_y} \n'
        if selected_ds_x is not None:
            textstr += f'==> Selected {x_axis} \n'
            textstr += f'Mean-{mean_v_x_s} | Std-{std_v_x_s} \n'
            textstr += f'Max-{max_v_x_s} | Min-{min_v_x_s} \n'
        if selected_ds_y is not None:
            textstr += f'==> Selected {y_axis} \n'
            textstr += f'Mean-{mean_v_y_s} | Std-{std_v_y_s} \n'
            textstr += f'Max-{max_v_y_s} | Min-{min_v_y_s} \n'
        if selected_topN_ds_y is not None:
            textstr += f'==> Predicted TopN (10) -{str(round(max(selected_topN_ds_y[:10]), 4))} \n'

        if train_idx_list is not None and oracle_idx in train_idx_list:
            textstr += f'==> Hit Oracle ({oracle_idx}) !'

        return textstr


def get_arch_acc_info_dict(nasbench201, dataname='cifar10-valid', arch_index_list=None):
    val_acc_list = []
    test_acc_list = []
    flops_list = []
    params_list = []
    latency_list = []

    for arch_index in arch_index_list:
        val_acc_list.append(nasbench201['val-acc'][dataname][arch_index])
        test_acc_list.append(nasbench201['test-acc'][dataname][arch_index])
        flops_list.append(nasbench201['flops'][dataname][arch_index])
        params_list.append(nasbench201['params'][dataname][arch_index])
        latency_list.append(nasbench201['latency'][dataname][arch_index])

    return {
        'val_acc_list': val_acc_list,
        'test_acc_list': test_acc_list,
        'flops_list': flops_list,
        'params_list': params_list,
        'latency_list': latency_list,
    }
167 MobileNetV3/configs/tr_meta_surrogate_ofa.py Normal file
@@ -0,0 +1,167 @@
import ml_collections
import torch
from all_path import SCORE_MODEL_CKPT_PATH, SCORE_MODEL_DATA_PATH


def get_config():
    config = ml_collections.ConfigDict()

    config.search_space = None

    # general
    config.resume = False
    config.folder_name = 'DiffusionNAG'
    config.task = 'tr_meta_predictor'
    config.exp_name = None
    config.model_type = 'meta_predictor'
    config.scorenet_ckpt_path = SCORE_MODEL_CKPT_PATH
    config.is_meta = True

    # training
    config.training = training = ml_collections.ConfigDict()
    training.sde = 'vesde'
    training.continuous = True
    training.reduce_mean = True
    training.noised = True

    training.batch_size = 128
    training.eval_batch_size = 512
    training.n_iters = 20000
    training.snapshot_freq = 500
    training.log_freq = 500
    training.eval_freq = 500
    ## store additional checkpoints for preemption
    training.snapshot_freq_for_preemption = 1000
    ## produce samples at each snapshot
    training.snapshot_sampling = True
    training.likelihood_weighting = False
    # training on perturbed data
    training.t_spot = 1.
    # training from a pretrained score model
    training.load_pretrained = False
    training.pretrained_model_path = SCORE_MODEL_CKPT_PATH

    # sampling
    config.sampling = sampling = ml_collections.ConfigDict()
    sampling.method = 'pc'
    sampling.predictor = 'euler_maruyama'
    sampling.corrector = 'none'  # or 'langevin'
    sampling.rtol = 1e-5
    sampling.atol = 1e-5
    sampling.ode_method = 'dopri5'  # or 'rk4'
    sampling.ode_step = 0.01

    sampling.n_steps_each = 1
    sampling.noise_removal = True
    sampling.probability_flow = False
    sampling.snr = 0.16
    sampling.vis_row = 4
    sampling.vis_col = 4

    # conditional
    sampling.classifier_scale = 1.0
    sampling.regress = True
    sampling.labels = 'max'
    sampling.weight_ratio = False
    sampling.weight_scheduling = False
    sampling.t_spot = 1.
    sampling.t_spot_end = 0.
    sampling.number_chain_steps = 50
    sampling.check_dataname = 'imagenet1k'

    # evaluation
    config.eval = evaluate = ml_collections.ConfigDict()
    evaluate.begin_ckpt = 5
    evaluate.end_ckpt = 20
    evaluate.batch_size = 128
    evaluate.enable_sampling = True
    evaluate.num_samples = 1024
    evaluate.mmd_distance = 'RBF'
    evaluate.max_subgraph = False
    evaluate.save_graph = False

    # data
    config.data = data = ml_collections.ConfigDict()
    data.centered = True
    data.dequantization = False

    data.root = SCORE_MODEL_DATA_PATH
    data.name = 'ofa'
    data.split_ratio = 0.8
    data.dataset_idx = 'random'
    data.max_node = 20
    data.n_vocab = 9
    data.START_TYPE = 0
    data.END_TYPE = 1
    data.num_graphs = 100000
    data.num_channels = 1
    data.except_inout = False  # ignored
    data.triu_adj = True
    data.connect_prev = False
    data.tg_dataset = None
    data.label_list = ['meta-acc']
    # aug_mask
    data.aug_mask_algo = 'none'  # 'long_range' | 'floyd'
    # num_train
    data.num_train = 150

    # model
    config.model = model = ml_collections.ConfigDict()
    model.name = 'MetaPredictorCATE'
    model.ema_rate = 0.9999
    model.normalization = 'GroupNorm'
    model.nonlinearity = 'swish'
    model.nf = 128
    model.num_gnn_layers = 4
    model.size_cond = False
    model.embedding_type = 'positional'
    model.rw_depth = 16
    model.graph_layer = 'PosTransLayer'
    model.edge_th = -1.
    model.heads = 8
    model.attn_clamp = False
    #############################################################################
    # meta
    model.input_type = 'DA'
    model.hs = 512
    model.nz = 56
    model.num_sample = 20

    model.num_scales = 1000
    model.beta_min = 0.1
    model.beta_max = 5.0
    model.sigma_min = 0.1
    model.sigma_max = 5.0
    model.dropout = 0.1
    # graph encoder
    config.model.graph_encoder = graph_encoder = ml_collections.ConfigDict()
    graph_encoder.n_layers = 2
    graph_encoder.d_model = 64
    graph_encoder.n_head = 2
    graph_encoder.d_ff = 32
    graph_encoder.dropout = 0.1
    graph_encoder.n_vocab = 9

    # optimization
    config.optim = optim = ml_collections.ConfigDict()
    optim.weight_decay = 0
    optim.optimizer = 'Adam'
    optim.lr = 0.001
    optim.beta1 = 0.9
    optim.eps = 1e-8
    optim.warmup = 1000
    optim.grad_clip = 1.

    config.seed = 42
    config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    # log
    config.log = log = ml_collections.ConfigDict()
    log.use_wandb = True
    log.wandb_project_name = 'DiffusionNAG'
    log.log_valid_sample_prop = False
    log.num_graphs_to_visualize = 20

    return config
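
# Usage sketch (assumption: the config is consumed by the repo's training
# entry point, e.g. via ml_collections' config_flags):
#   from configs.tr_meta_surrogate_ofa import get_config
#   config = get_config()
#   print(config.model_type, config.training.n_iters)  # meta_predictor 20000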
141 MobileNetV3/configs/tr_scorenet_ofa.py Normal file
@@ -0,0 +1,141 @@
"""Config for training the PGSN score network on the OFA search space."""

import ml_collections
import torch


def get_config():
    config = ml_collections.ConfigDict()

    # general
    config.resume = False
    config.resume_ckpt_path = './exp'
    config.folder_name = 'tr_scorenet'
    config.task = 'tr_scorenet'
    config.exp_name = None

    config.model_type = 'sde'

    # training
    config.training = training = ml_collections.ConfigDict()
    training.sde = 'vesde'
    training.continuous = True
    training.reduce_mean = True

    training.batch_size = 256
    training.eval_batch_size = 1000
    training.n_iters = 1000000
    training.snapshot_freq = 10000
    training.log_freq = 200
    training.eval_freq = 10000
    ## store additional checkpoints for preemption
    training.snapshot_freq_for_preemption = 5000
    ## produce samples at each snapshot
    training.snapshot_sampling = True
    training.likelihood_weighting = False

    # sampling
    config.sampling = sampling = ml_collections.ConfigDict()
    sampling.method = 'pc'
    sampling.predictor = 'euler_maruyama'
    sampling.corrector = 'none'
    sampling.rtol = 1e-5
    sampling.atol = 1e-5
    sampling.ode_method = 'dopri5'  # or 'rk4'
    sampling.ode_step = 0.01

    sampling.n_steps_each = 1
    sampling.noise_removal = True
    sampling.probability_flow = False
    sampling.snr = 0.16
    sampling.vis_row = 4
    sampling.vis_col = 4
    sampling.alpha = 0.5
    sampling.qtype = 'threshold'

    # evaluation
    config.eval = evaluate = ml_collections.ConfigDict()
    evaluate.begin_ckpt = 5
    evaluate.end_ckpt = 20
    evaluate.batch_size = 1024
    evaluate.enable_sampling = True
    evaluate.num_samples = 1024
    evaluate.mmd_distance = 'RBF'
    evaluate.max_subgraph = False
    evaluate.save_graph = False

    # data
    config.data = data = ml_collections.ConfigDict()
    data.centered = True
    data.dequantization = False

    data.root = './data/ofa/data_score_model/ofa_database_500000.pt'
    data.name = 'ofa'
    data.split_ratio = 0.9
    data.dataset_idx = 'random'
    data.max_node = 20
    data.n_vocab = 9
    data.START_TYPE = 0
    data.END_TYPE = 1
    data.num_graphs = 100000
    data.num_channels = 1
    data.except_inout = False
    data.triu_adj = True
    data.connect_prev = False
    data.label_list = None
    data.tg_dataset = None
    data.node_rule_type = 2
    # aug_mask
    data.aug_mask_algo = 'none'

    # model
    config.model = model = ml_collections.ConfigDict()
    model.name = 'CATE'
    model.ema_rate = 0.9999
    model.normalization = 'GroupNorm'
    model.nonlinearity = 'swish'
    model.nf = 128
    model.num_gnn_layers = 4
    model.size_cond = False
    model.embedding_type = 'positional'
    model.rw_depth = 16
    model.graph_layer = 'PosTransLayer'
    model.edge_th = -1.
    model.heads = 8
    model.attn_clamp = False

    model.num_scales = 1000
    model.sigma_min = 0.1
    model.sigma_max = 1.0
    model.dropout = 0.1
    model.pos_enc_type = 2
    # graph encoder
    config.model.graph_encoder = graph_encoder = ml_collections.ConfigDict()
    graph_encoder.n_layers = 12
    graph_encoder.d_model = 64
    graph_encoder.n_head = 8
    graph_encoder.d_ff = 128
    graph_encoder.dropout = 0.1
    graph_encoder.n_vocab = 9

    # optimization
    config.optim = optim = ml_collections.ConfigDict()
    optim.weight_decay = 0
    optim.optimizer = 'Adam'
    optim.lr = 2e-5
    optim.beta1 = 0.9
    optim.eps = 1e-8
    optim.warmup = 1000
    optim.grad_clip = 1.

    config.seed = 42
    config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    # log
    config.log = log = ml_collections.ConfigDict()
    log.use_wandb = True
    log.wandb_project_name = 'DiffusionNAG'
    log.log_valid_sample_prop = False
    log.num_graphs_to_visualize = 20

    return config
|
||||
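A minimal sketch of how an ml_collections config like this is typically consumed downstream; the import path and the override are illustrative assumptions, not code from this commit:

# Hypothetical usage sketch (assumes it runs from the MobileNetV3 directory).
from configs.tr_scorenet_ofa import get_config

config = get_config()
config.training.batch_size = 128       # ConfigDict fields can be overridden after creation
assert config.data.max_node == 20      # matches MAX_N_BLOCK of the OFA search space
print(config.sampling.predictor)       # -> 'euler_maruyama'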
493
MobileNetV3/datasets_nas.py
Normal file
@@ -0,0 +1,493 @@
from __future__ import print_function
import os

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch_geometric.utils import to_networkx

from analysis.arch_functions import OPS, get_x_adj_from_opsdict_ofa, get_string_from_onehot_x
from all_path import PROCESSED_DATA_PATH, SCORE_MODEL_DATA_IDX_PATH


def get_data_scaler(config):
    """Data normalizer. Assumes data are always in [0, 1]."""
    if config.data.centered:
        # Rescale to [-1, 1]
        return lambda x: x * 2. - 1.
    else:
        return lambda x: x


def get_data_inverse_scaler(config):
    """Inverse data normalizer."""
    if config.data.centered:
        # Rescale [-1, 1] to [0, 1]
        return lambda x: (x + 1.) / 2.
    else:
        return lambda x: x


def networkx_graphs(dataset):
    return [to_networkx(dataset[i], to_undirected=False, remove_self_loops=True) for i in range(len(dataset))]


def get_dataloader(config, train_dataset, eval_dataset, test_dataset):
    collate = collate_fn_ofa if config.model_type == 'meta_predictor' else None
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=config.training.batch_size,
                              shuffle=True,
                              collate_fn=collate)
    eval_loader = DataLoader(dataset=eval_dataset,
                             batch_size=config.training.batch_size,
                             shuffle=False,
                             collate_fn=collate)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=config.training.batch_size,
                             shuffle=False,
                             collate_fn=collate)

    return train_loader, eval_loader, test_loader


def get_dataloader_iter(config, train_dataset, eval_dataset, test_dataset):
    # Clamp the batch size to the dataset size so small datasets still yield one full batch.
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=min(config.training.batch_size, len(train_dataset)),
                              shuffle=True)
    eval_loader = DataLoader(dataset=eval_dataset,
                             batch_size=min(config.training.batch_size, len(eval_dataset)),
                             shuffle=False)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=min(config.training.batch_size, len(test_dataset)),
                             shuffle=False)

    return train_loader, eval_loader, test_loader


def is_triu(mat):
    return np.allclose(mat, np.triu(mat))


def collate_fn_ofa(batch):
    # Each item is (x, adj, label_dict, task).
    x = torch.stack([item[0] for item in batch])
    adj = torch.stack([item[1] for item in batch])
    label_dict = {}
    for item in batch:
        for k, v in item[2].items():
            if k not in label_dict:
                label_dict[k] = []
            label_dict[k].append(v)
    for k, v in label_dict.items():
        label_dict[k] = torch.tensor(v)
    task = [item[3] for item in batch]
    return x, adj, label_dict, task


def get_dataset(config):
    """Create datasets for training and evaluation.

    Args:
        config: A ml_collections.ConfigDict parsed from config files.

    Returns:
        train_ds, eval_ds, test_ds
    """
    num_train = config.data.num_train if 'num_train' in config.data else None
    NASDataset = OFADataset

    train_dataset = NASDataset(
        config.data.root,
        config.data.split_ratio,
        config.data.except_inout,
        config.data.triu_adj,
        config.data.connect_prev,
        'train',
        config.data.label_list,
        config.data.tg_dataset,
        config.data.dataset_idx,
        num_train,
        node_rule_type=config.data.node_rule_type)
    eval_dataset = NASDataset(
        config.data.root,
        config.data.split_ratio,
        config.data.except_inout,
        config.data.triu_adj,
        config.data.connect_prev,
        'eval',
        config.data.label_list,
        config.data.tg_dataset,
        config.data.dataset_idx,
        num_train,
        node_rule_type=config.data.node_rule_type)
    test_dataset = NASDataset(
        config.data.root,
        config.data.split_ratio,
        config.data.except_inout,
        config.data.triu_adj,
        config.data.connect_prev,
        'test',
        config.data.label_list,
        config.data.tg_dataset,
        config.data.dataset_idx,
        num_train,
        node_rule_type=config.data.node_rule_type)

    return train_dataset, eval_dataset, test_dataset


def get_meta_dataset(config):
    database = MetaTrainDatabaseOFA
    data_path = PROCESSED_DATA_PATH

    train_dataset = database(
        data_path,
        config.model.num_sample,
        config.data.label_list,
        True,
        config.data.except_inout,
        config.data.triu_adj,
        config.data.connect_prev,
        'train')
    eval_dataset = database(
        data_path,
        config.model.num_sample,
        config.data.label_list,
        True,
        config.data.except_inout,
        config.data.triu_adj,
        config.data.connect_prev,
        'val')
    test_dataset = None
    return train_dataset, eval_dataset, test_dataset


def get_meta_dataloader(config, train_dataset, eval_dataset, test_dataset):
    if config.data.name == 'ofa':
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=config.training.batch_size,
                                  shuffle=True)
        eval_loader = DataLoader(dataset=eval_dataset,
                                 batch_size=config.training.batch_size)
    else:
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=config.training.batch_size,
                                  shuffle=True)
        eval_loader = DataLoader(dataset=eval_dataset,
                                 batch_size=config.training.batch_size,
                                 shuffle=False)
    test_loader = None
    return train_loader, eval_loader, test_loader


class MetaTestDataset(Dataset):
    def __init__(self, data_path, data_name, num_sample, num_class=None):
        self.num_sample = num_sample
        self.data_name = data_name

        num_class_dict = {
            'cifar100': 100,
            'cifar10': 10,
            'mnist': 10,
            'svhn': 10,
            'aircraft': 30,
            'pets': 37
        }

        if num_class is not None:
            self.num_class = num_class
        else:
            self.num_class = num_class_dict[data_name]
        fname = 'aircraft100bylabel.pt' if 'ofa' in data_path and data_name == 'aircraft' else f'{data_name}bylabel.pt'
        self.x = torch.load(os.path.join(data_path, fname))

    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        # Sample `num_sample` instances per class to represent the task.
        data = []
        classes = list(range(self.num_class))
        for cls in classes:
            cx = self.x[cls][0]
            ridx = torch.randperm(len(cx))
            data.append(cx[ridx[:self.num_sample]])
        x = torch.cat(data)
        return x


class MetaTrainDatabaseOFA(Dataset):
    def __init__(
            self,
            data_path,
            num_sample,
            label_list,
            is_pred=True,
            except_inout=False,
            triu_adj=True,
            connect_prev=False,
            mode='train'):
        self.ops_decoder = list(OPS.keys())
        self.mode = mode
        self.acc_norm = True
        self.num_sample = num_sample
        self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))

        if is_pred:
            self.dpath = f'{data_path}/predictor/processed/'
        else:
            raise NotImplementedError

        self.dname = 'database_219152_14.0K'
        data = torch.load(self.dpath + f'{self.dname}_{self.mode}.pt')
        self.net = data['net']
        self.x_list = []
        self.adj_list = []
        self.arch_str_list = []
        for net in self.net:
            x, adj = get_x_adj_from_opsdict_ofa(net)
            # ---------- matrix ---------- #
            self.x_list.append(x)
            self.adj_list.append(torch.tensor(adj))
            # ---------- arch_str ---------- #
            self.arch_str_list.append(get_string_from_onehot_x(x))
        # ---------- labels ---------- #
        self.label_list = label_list
        if self.label_list is not None:
            self.flops_list = data['flops']
            self.params_list = None
            self.latency_list = None

        self.acc_list = data['acc']
        self.mean = data['mean']
        self.std = data['std']
        self.task_lst = data['class']

    def __len__(self):
        return len(self.acc_list)

    def __getitem__(self, index):
        data = []
        classes = self.task_lst[index]
        acc = self.acc_list[index]

        # ---------- x ---------- #
        x = self.x_list[index]
        # ---------- adj ---------- #
        adj = self.adj_list[index]

        for i, cls in enumerate(classes):
            cx = self.x[cls.item()][0]
            ridx = torch.randperm(len(cx))
            data.append(cx[ridx[:self.num_sample]])
        task = torch.cat(data)
        if self.acc_norm:
            acc = ((acc - self.mean) / self.std) / 100.0
        else:
            acc = acc / 100.0

        label_dict = {}
        if self.label_list is not None:
            assert type(self.label_list) == list
            for label in self.label_list:
                if label == 'meta-acc':
                    label_dict[f"{label}"] = acc
                else:
                    raise ValueError
        return x, adj, label_dict, task


class OFADataset(Dataset):
    def __init__(
            self,
            data_path,
            split_ratio=0.8,
            except_inout=False,
            triu_adj=True,
            connect_prev=False,
            mode='train',
            label_list=None,
            tg_dataset=None,
            dataset_idx='random',
            num_train=None,
            node_rule_type=None):
        # ---------- entire dataset ---------- #
        self.data = torch.load(data_path)
        self.except_inout = except_inout
        self.triu_adj = triu_adj
        self.connect_prev = connect_prev
        self.node_rule_type = node_rule_type

        # ---------- x ---------- #
        self.x_list = self.data['x_none2zero']

        # ---------- adj ---------- #
        assert self.connect_prev == False
        self.n_adj = len(self.data['node_type'][0])
        const_adj = self.get_not_connect_prev_adj()
        self.adj_list = [const_adj] * len(self.x_list)

        # ---------- arch_str ---------- #
        self.arch_str_list = self.data['net_setting']
        # ---------- labels ---------- #
        self.label_list = label_list
        if self.label_list is not None:
            raise NotImplementedError

        # ---------- split dataset ---------- #
        self.ds_idx = list(torch.load(SCORE_MODEL_DATA_IDX_PATH))

        self.split_ratio = split_ratio
        if num_train is None:
            num_train = int(len(self.x_list) * self.split_ratio)
        num_test = len(self.x_list) - num_train

        # ---------- compute mean and std w/ training dataset ---------- #
        if self.label_list is not None:
            self.train_idx_list = self.ds_idx[:num_train]
            print('Computing mean and std of the training set...')
            from collections import defaultdict
            LABEL_TO_MEAN_STD = defaultdict(dict)
            assert type(self.label_list) == list
            for label in self.label_list:
                if label == 'test-acc':
                    self.test_acc_list_tr = [self.test_acc_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.test_acc_list_tr))
                elif label == 'flops':
                    self.flops_list_tr = [self.flops_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.flops_list_tr))
                elif label == 'params':
                    self.params_list_tr = [self.params_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.params_list_tr))
                elif label == 'latency':
                    self.latency_list_tr = [self.latency_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.latency_list_tr))
                else:
                    raise ValueError

        self.mode = mode
        if self.mode in ['train']:
            self.idx_list = self.ds_idx[:num_train]
        elif self.mode in ['eval']:
            self.idx_list = self.ds_idx[:num_test]
        elif self.mode in ['test']:
            self.idx_list = self.ds_idx[num_train:]

        self.x_list_ = [self.x_list[i] for i in self.idx_list]
        self.adj_list_ = [self.adj_list[i] for i in self.idx_list]
        self.arch_str_list_ = [self.arch_str_list[i] for i in self.idx_list]

        if self.label_list is not None:
            assert type(self.label_list) == list
            for label in self.label_list:
                if label == 'test-acc':
                    self.test_acc_list_ = [self.test_acc_list[i] for i in self.idx_list]
                    self.test_acc_list_ = self.normalize(self.test_acc_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                elif label == 'flops':
                    self.flops_list_ = [self.flops_list[i] for i in self.idx_list]
                    self.flops_list_ = self.normalize(self.flops_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                elif label == 'params':
                    self.params_list_ = [self.params_list[i] for i in self.idx_list]
                    self.params_list_ = self.normalize(self.params_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                elif label == 'latency':
                    self.latency_list_ = [self.latency_list[i] for i in self.idx_list]
                    self.latency_list_ = self.normalize(self.latency_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                else:
                    raise ValueError

    def normalize(self, original, mean, std):
        return [(i - mean) / std for i in original]

    def get_not_connect_prev_adj(self):
        # Chain adjacency: each node is connected only to its immediate successor.
        _adj = torch.zeros(self.n_adj, self.n_adj)
        for i in range(self.n_adj - 1):
            _adj[i, i + 1] = 1
        _adj = _adj.to(torch.float32).to('cpu')
        return _adj

    @property
    def adj(self):
        return self.adj_list_[0]

    def mask(self, algo='floyd', data='ofa'):
        from utils import aug_mask
        return aug_mask(self.adj, algo=algo, data=data)[0]

    def get_unnormalized_entire_data(self, label, tg_dataset):
        entire_test_acc_list = self.data['test-acc'][tg_dataset]
        entire_flops_list = self.data['flops'][tg_dataset]
        entire_params_list = self.data['params'][tg_dataset]
        entire_latency_list = self.data['latency'][tg_dataset]

        if label == 'test-acc':
            return entire_test_acc_list
        elif label == 'flops':
            return entire_flops_list
        elif label == 'params':
            return entire_params_list
        elif label == 'latency':
            return entire_latency_list
        else:
            raise ValueError

    def get_unnormalized_data(self, label, tg_dataset):
        entire_test_acc_list = self.data['test-acc'][tg_dataset]
        entire_flops_list = self.data['flops'][tg_dataset]
        entire_params_list = self.data['params'][tg_dataset]
        entire_latency_list = self.data['latency'][tg_dataset]

        if label == 'test-acc':
            return [entire_test_acc_list[i] for i in self.idx_list]
        elif label == 'flops':
            return [entire_flops_list[i] for i in self.idx_list]
        elif label == 'params':
            return [entire_params_list[i] for i in self.idx_list]
        elif label == 'latency':
            return [entire_latency_list[i] for i in self.idx_list]
        else:
            raise ValueError

    def __len__(self):
        return len(self.x_list_)

    def __getitem__(self, index):
        label_dict = {}
        if self.label_list is not None:
            assert type(self.label_list) == list
            for label in self.label_list:
                if label == 'test-acc':
                    label_dict[f"{label}"] = self.test_acc_list_[index]
                elif label == 'flops':
                    label_dict[f"{label}"] = self.flops_list_[index]
                elif label == 'params':
                    label_dict[f"{label}"] = self.params_list_[index]
                elif label == 'latency':
                    label_dict[f"{label}"] = self.latency_list_[index]
                else:
                    raise ValueError

        return self.x_list_[index], self.adj_list_[index], label_dict
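A small sketch of what collate_fn_ofa produces; the dummy shapes are assumptions chosen to mimic the OFA encoding (20 nodes, 9 op types), not values taken from the stored database:

# Hypothetical usage sketch of collate_fn_ofa with fabricated items.
import torch

def make_item():
    x = torch.zeros(20, 9)                    # one-hot node features
    adj = torch.triu(torch.ones(20, 20), 1)   # upper-triangular adjacency
    return x, adj, {'meta-acc': 0.5}, torch.randn(4, 512)

x, adj, label_dict, task = collate_fn_ofa([make_item() for _ in range(8)])
print(x.shape, adj.shape, label_dict['meta-acc'].shape, len(task))
# -> torch.Size([8, 20, 9]) torch.Size([8, 20, 20]) torch.Size([8]) 8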
1
MobileNetV3/evaluation/__init__.py
Executable file
@@ -0,0 +1 @@
from .evaluator import get_stats_eval, get_nn_eval
58
MobileNetV3/evaluation/evaluator.py
Normal file
@@ -0,0 +1,58 @@
import networkx as nx
from torch_geometric.utils import to_networkx

from .structure_evaluator import mmd_eval
from .gin_evaluator import nn_based_eval


def get_stats_eval(config):
    if config.eval.mmd_distance.lower() == 'rbf':
        method = [('degree', 1., 'argmax'), ('cluster', 0.1, 'argmax'),
                  ('spectral', 1., 'argmax')]
    else:
        raise ValueError

    def eval_stats_fn(test_dataset, pred_graph_list):
        pred_G = [nx.from_numpy_matrix(pred_adj) for pred_adj in pred_graph_list]
        if config.eval.max_subgraph:
            # Keep only the largest connected component of each predicted graph.
            sub_pred_G = []
            for G in pred_G:
                CGs = [G.subgraph(c) for c in nx.connected_components(G)]
                CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
                sub_pred_G += [CGs[0]]
            pred_G = sub_pred_G

        test_G = [to_networkx(test_dataset[i], to_undirected=True, remove_self_loops=True)
                  for i in range(len(test_dataset))]
        results = mmd_eval(test_G, pred_G, method)
        return results

    return eval_stats_fn


def get_nn_eval(config):
    if hasattr(config.eval, "N_gin"):
        N_gin = config.eval.N_gin
    else:
        N_gin = 10

    def nn_eval_fn(test_dataset, pred_graph_list):
        pred_G = [nx.from_numpy_matrix(pred_adj) for pred_adj in pred_graph_list]
        if config.eval.max_subgraph:
            # Keep only the largest connected component of each predicted graph.
            sub_pred_G = []
            for G in pred_G:
                CGs = [G.subgraph(c) for c in nx.connected_components(G)]
                CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
                sub_pred_G += [CGs[0]]
            pred_G = sub_pred_G
        test_G = [to_networkx(test_dataset[i], to_undirected=True, remove_self_loops=True)
                  for i in range(len(test_dataset))]

        results = nn_based_eval(test_G, pred_G, N_gin)
        return results

    return nn_eval_fn
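A sketch of how these factories are wired up; the config stub is an assumption standing in for the real ConfigDict:

# Hypothetical usage sketch: build the MMD-statistics evaluator from a config stub.
import ml_collections

config = ml_collections.ConfigDict()
config.eval = ml_collections.ConfigDict()
config.eval.mmd_distance = 'RBF'
config.eval.max_subgraph = False

eval_stats_fn = get_stats_eval(config)
# results = eval_stats_fn(test_dataset, pred_adj_list)
# -> {'degree_rbf': ..., 'cluster_rbf': ..., 'spectral_rbf': ...}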
311
MobileNetV3/evaluation/gin.py
Normal file
@@ -0,0 +1,311 @@
"""Modified from https://github.com/uoguelph-mlrg/GGM-metrics"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.utils import expand_as_pair
from dgl.nn import SumPooling, AvgPooling, MaxPooling


class GINConv(nn.Module):
    def __init__(self,
                 apply_func,
                 aggregator_type,
                 init_eps=0,
                 learn_eps=False):
        super(GINConv, self).__init__()
        self.apply_func = apply_func
        self._aggregator_type = aggregator_type
        if aggregator_type == 'sum':
            self._reducer = fn.sum
        elif aggregator_type == 'max':
            self._reducer = fn.max
        elif aggregator_type == 'mean':
            self._reducer = fn.mean
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type))
        # Specify whether eps is trainable or not.
        if learn_eps:
            self.eps = torch.nn.Parameter(torch.FloatTensor([init_eps]))
        else:
            self.register_buffer('eps', torch.FloatTensor([init_eps]))

    def forward(self, graph, feat, edge_weight=None):
        r"""Compute a Graph Isomorphism Network layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
            :math:`D_{in}` is the size of the input feature and :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.
            If ``apply_func`` is not None, :math:`D_{in}` should
            fit the input dimensionality requirement of ``apply_func``.
        edge_weight : torch.Tensor, optional
            Optional tensor on the edge. If given, the convolution will weight
            with regard to the message.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, D_{out})` where
            :math:`D_{out}` is the output dimensionality of ``apply_func``.
            If ``apply_func`` is None, :math:`D_{out}` should be the same
            as the input dimensionality.
        """
        with graph.local_scope():
            aggregate_fn = self.concat_edge_msg
            if edge_weight is not None:
                assert edge_weight.shape[0] == graph.number_of_edges()
                graph.edata['_edge_weight'] = edge_weight
                aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')

            feat_src, feat_dst = expand_as_pair(feat, graph)
            graph.srcdata['h'] = feat_src
            graph.update_all(aggregate_fn, self._reducer('m', 'neigh'))

            # Zero-pad feat_dst so it matches the aggregated message width
            # (messages may be wider when edge features are concatenated).
            diff = torch.tensor(graph.dstdata['neigh'].shape[1:]) - torch.tensor(feat_dst.shape[1:])
            zeros = torch.zeros(feat_dst.shape[0], *diff).to(feat_dst.device)
            feat_dst = torch.cat([feat_dst, zeros], dim=1)
            rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
            if self.apply_func is not None:
                rst = self.apply_func(rst)
            return rst

    def concat_edge_msg(self, edges):
        if self.edge_feat_loc not in edges.data:
            return {'m': edges.src['h']}
        else:
            m = torch.cat([edges.src['h'], edges.data[self.edge_feat_loc]], dim=1)
            return {'m': m}


class ApplyNodeFunc(nn.Module):
    """Update the node feature hv with MLP, BN and ReLU."""
    def __init__(self, mlp):
        super(ApplyNodeFunc, self).__init__()
        self.mlp = mlp
        self.bn = nn.BatchNorm1d(self.mlp.output_dim)

    def forward(self, h):
        h = self.mlp(h)
        h = self.bn(h)
        h = F.relu(h)
        return h


class MLP(nn.Module):
    """MLP with linear output"""
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        """MLP layers construction

        Parameters
        ----------
        num_layers: int
            The number of linear layers
        input_dim: int
            The dimensionality of input features
        hidden_dim: int
            The dimensionality of hidden units at ALL layers
        output_dim: int
            The number of classes for prediction
        """
        super(MLP, self).__init__()
        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.output_dim = output_dim

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        elif num_layers == 1:
            # Linear model
            self.linear = nn.Linear(input_dim, output_dim)
        else:
            # Multi-layer model
            self.linear_or_not = False
            self.linears = torch.nn.ModuleList()
            self.batch_norms = torch.nn.ModuleList()

            self.linears.append(nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers - 2):
                self.linears.append(nn.Linear(hidden_dim, hidden_dim))
            self.linears.append(nn.Linear(hidden_dim, output_dim))

            for layer in range(num_layers - 1):
                self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

    def forward(self, x):
        if self.linear_or_not:
            # Linear model
            return self.linear(x)
        else:
            # MLP
            h = x
            for i in range(self.num_layers - 1):
                h = F.relu(self.batch_norms[i](self.linears[i](h)))
            return self.linears[-1](h)


class GIN(nn.Module):
    """GIN model"""
    def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim,
                 graph_pooling_type, neighbor_pooling_type, edge_feat_dim=0,
                 final_dropout=0.0, learn_eps=False, output_dim=1, **kwargs):
        """Model parameter settings

        Parameters
        ----------
        num_layers: int
            The number of linear layers in the neural network
        num_mlp_layers: int
            The number of linear layers in mlps
        input_dim: int
            The dimensionality of input features
        hidden_dim: int
            The dimensionality of hidden units at ALL layers
        output_dim: int
            The number of classes for prediction
        final_dropout: float
            dropout ratio on the final linear layer
        learn_eps: boolean
            If True, learn epsilon to distinguish center nodes from neighbors;
            if False, aggregate neighbors and center nodes altogether.
        neighbor_pooling_type: str
            how to aggregate neighbors (sum, mean, or max)
        graph_pooling_type: str
            how to aggregate entire nodes in a graph (sum, mean or max)
        """
        super().__init__()

        def init_weights_orthogonal(m):
            if isinstance(m, nn.Linear):
                torch.nn.init.orthogonal_(m.weight)
            elif isinstance(m, MLP):
                if hasattr(m, 'linears'):
                    m.linears.apply(init_weights_orthogonal)
                else:
                    m.linear.apply(init_weights_orthogonal)
            elif isinstance(m, nn.ModuleList):
                pass
            else:
                raise Exception()

        self.num_layers = num_layers
        self.learn_eps = learn_eps

        # List of MLPs
        self.ginlayers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(self.num_layers - 1):
            if layer == 0:
                mlp = MLP(num_mlp_layers, input_dim + edge_feat_dim, hidden_dim, hidden_dim)
            else:
                mlp = MLP(num_mlp_layers, hidden_dim + edge_feat_dim, hidden_dim, hidden_dim)
            if kwargs['init'] == 'orthogonal':
                init_weights_orthogonal(mlp)

            self.ginlayers.append(
                GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.learn_eps))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # Linear functions for graph poolings of the output of each layer,
        # which map the output of different layers into a prediction score.
        self.linears_prediction = torch.nn.ModuleList()

        for layer in range(num_layers):
            if layer == 0:
                self.linears_prediction.append(
                    nn.Linear(input_dim, output_dim))
            else:
                self.linears_prediction.append(
                    nn.Linear(hidden_dim, output_dim))

        if kwargs['init'] == 'orthogonal':
            self.linears_prediction.apply(init_weights_orthogonal)

        self.drop = nn.Dropout(final_dropout)

        if graph_pooling_type == 'sum':
            self.pool = SumPooling()
        elif graph_pooling_type == 'mean':
            self.pool = AvgPooling()
        elif graph_pooling_type == 'max':
            self.pool = MaxPooling()
        else:
            raise NotImplementedError

    def forward(self, g, h):
        # List of hidden representations at each layer (including the input).
        hidden_rep = [h]

        for i in range(self.num_layers - 1):
            h = self.ginlayers[i](g, h)
            h = self.batch_norms[i](h)
            h = F.relu(h)
            hidden_rep.append(h)

        score_over_layer = 0

        # Perform pooling over all nodes in each graph at every layer.
        for i, h in enumerate(hidden_rep):
            pooled_h = self.pool(g, h)
            score_over_layer += self.drop(self.linears_prediction[i](pooled_h))
        return score_over_layer

    def get_graph_embed(self, g, h):
        self.eval()
        with torch.no_grad():
            hidden_rep = []
            for i in range(self.num_layers - 1):
                h = self.ginlayers[i](g, h)
                h = self.batch_norms[i](h)
                h = F.relu(h)
                hidden_rep.append(h)

            # Pool over all nodes in each graph at every layer and concatenate.
            graph_embed = torch.Tensor([]).to(self.device)
            for i, h in enumerate(hidden_rep):
                pooled_h = self.pool(g, h)
                graph_embed = torch.cat([graph_embed, pooled_h], dim=1)

            return graph_embed

    def get_graph_embed_no_cat(self, g, h):
        self.eval()
        with torch.no_grad():
            hidden_rep = []
            for i in range(self.num_layers - 1):
                h = self.ginlayers[i](g, h)
                h = self.batch_norms[i](h)
                h = F.relu(h)
                hidden_rep.append(h)

            return self.pool(g, hidden_rep[-1]).to(self.device)

    @property
    def edge_feat_loc(self):
        return self.ginlayers[0].edge_feat_loc

    @edge_feat_loc.setter
    def edge_feat_loc(self, loc):
        for layer in self.ginlayers:
            layer.edge_feat_loc = loc
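A quick smoke-test sketch for the GIN above; the graph and feature sizes are arbitrary assumptions:

# Hypothetical usage sketch: run a random DGL graph through the GIN.
import dgl
import torch

model = GIN(num_layers=3, num_mlp_layers=2, input_dim=8, hidden_dim=16,
            graph_pooling_type='sum', neighbor_pooling_type='sum', init='orthogonal')
model.edge_feat_loc = 'attr'   # concat_edge_msg expects this attribute to be set

g = dgl.rand_graph(10, 30)     # 10 nodes, 30 random edges
h = torch.randn(10, 8)         # node features matching input_dim
score = model(g, h)            # pooled per-graph prediction
print(score.shape)             # -> torch.Size([1, 1])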
292
MobileNetV3/evaluation/gin_evaluator.py
Normal file
@@ -0,0 +1,292 @@
"""Evaluation on random GIN features. Modified from https://github.com/uoguelph-mlrg/GGM-metrics"""

import time

import torch
import numpy as np
import sklearn
import sklearn.metrics
from sklearn.preprocessing import StandardScaler
import dgl

from .gin import GIN


def load_feature_extractor(
        device, num_layers=3, hidden_dim=35, neighbor_pooling_type='sum',
        graph_pooling_type='sum', input_dim=1, edge_feat_dim=0,
        dont_concat=False, num_mlp_layers=2, output_dim=1,
        node_feat_loc='attr', edge_feat_loc='attr', init='orthogonal',
        **kwargs):
    model = GIN(num_layers=num_layers, hidden_dim=hidden_dim, neighbor_pooling_type=neighbor_pooling_type,
                graph_pooling_type=graph_pooling_type, input_dim=input_dim, edge_feat_dim=edge_feat_dim,
                num_mlp_layers=num_mlp_layers, output_dim=output_dim, init=init)

    model.node_feat_loc = node_feat_loc
    model.edge_feat_loc = edge_feat_loc

    model.eval()

    if dont_concat:
        model.forward = model.get_graph_embed_no_cat
    else:
        model.forward = model.get_graph_embed

    model.device = device
    return model.to(device)


def time_function(func):
    def wrapper(*args, **kwargs):
        start = time.time()
        results = func(*args, **kwargs)
        end = time.time()
        return results, end - start
    return wrapper


class GINMetric():
    def __init__(self, model):
        self.feat_extractor = model
        self.get_activations = self.get_activations_gin

    @time_function
    def get_activations_gin(self, generated_dataset, reference_dataset):
        return self._get_activations(generated_dataset, reference_dataset)

    def _get_activations(self, generated_dataset, reference_dataset):
        gen_activations = self.__get_activations_single_dataset(generated_dataset)
        ref_activations = self.__get_activations_single_dataset(reference_dataset)

        # Standardize both sets of activations with statistics of the reference set.
        scaler = StandardScaler()
        scaler.fit(ref_activations)
        ref_activations = scaler.transform(ref_activations)
        gen_activations = scaler.transform(gen_activations)

        return gen_activations, ref_activations

    def __get_activations_single_dataset(self, dataset):
        node_feat_loc = self.feat_extractor.node_feat_loc
        edge_feat_loc = self.feat_extractor.edge_feat_loc

        ndata = [node_feat_loc] if node_feat_loc in dataset[0].ndata else '__ALL__'
        edata = [edge_feat_loc] if edge_feat_loc in dataset[0].edata else '__ALL__'
        graphs = dgl.batch(dataset, ndata=ndata, edata=edata).to(self.feat_extractor.device)

        if node_feat_loc not in graphs.ndata:  # Use degree as features
            feats = graphs.in_degrees() + graphs.out_degrees()
            feats = feats.unsqueeze(1).type(torch.float32)
        else:
            feats = graphs.ndata[node_feat_loc]

        graph_embeds = self.feat_extractor(graphs, feats)
        return graph_embeds.cpu().detach().numpy()

    def evaluate(self, *args, **kwargs):
        raise Exception('Must be implemented by child class')


class MMDEvaluation(GINMetric):
    def __init__(self, model, kernel='rbf', sigma='range', multiplier='mean'):
        super().__init__(model)

        if multiplier == 'mean':
            self.__get_sigma_mult_factor = self.__mean_pairwise_distance
        elif multiplier == 'median':
            self.__get_sigma_mult_factor = self.__median_pairwise_distance
        elif multiplier is None:
            self.__get_sigma_mult_factor = lambda *args, **kwargs: 1
        else:
            raise Exception(multiplier)

        if 'rbf' in kernel:
            if sigma == 'range':
                self.base_sigmas = np.array([0.01, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0])

                if multiplier == 'mean':
                    self.name = 'mmd_rbf'
                elif multiplier == 'median':
                    self.name = 'mmd_rbf_adaptive_median'
                else:
                    self.name = 'mmd_rbf_adaptive'
            elif sigma == 'one':
                self.base_sigmas = np.array([1])

                if multiplier == 'mean':
                    self.name = 'mmd_rbf_single_mean'
                elif multiplier == 'median':
                    self.name = 'mmd_rbf_single_median'
                else:
                    self.name = 'mmd_rbf_single'
            else:
                raise Exception(sigma)

            self.evaluate = self.calculate_MMD_rbf_quadratic
        elif 'linear' in kernel:
            self.evaluate = self.calculate_MMD_linear_kernel
        else:
            raise Exception()

    def __get_pairwise_distances(self, generated_dataset, reference_dataset):
        return sklearn.metrics.pairwise_distances(reference_dataset, generated_dataset, metric='euclidean', n_jobs=8) ** 2

    def __mean_pairwise_distance(self, dists_GR):
        return np.sqrt(dists_GR.mean())

    def __median_pairwise_distance(self, dists_GR):
        return np.sqrt(np.median(dists_GR))

    def get_sigmas(self, dists_GR):
        mult_factor = self.__get_sigma_mult_factor(dists_GR)
        return self.base_sigmas * mult_factor

    @time_function
    def calculate_MMD_rbf_quadratic(self, generated_dataset=None, reference_dataset=None):
        # https://github.com/djsutherland/opt-mmd/blob/master/two_sample/mmd.py
        if not isinstance(generated_dataset, torch.Tensor) and not isinstance(generated_dataset, np.ndarray):
            (generated_dataset, reference_dataset), _ = self.get_activations(generated_dataset, reference_dataset)

        GG = self.__get_pairwise_distances(generated_dataset, generated_dataset)
        GR = self.__get_pairwise_distances(generated_dataset, reference_dataset)
        RR = self.__get_pairwise_distances(reference_dataset, reference_dataset)

        max_mmd = 0
        sigmas = self.get_sigmas(GR)

        # Report the maximum MMD over the sigma grid.
        for sigma in sigmas:
            gamma = 1 / (2 * sigma ** 2)

            K_GR = np.exp(-gamma * GR)
            K_GG = np.exp(-gamma * GG)
            K_RR = np.exp(-gamma * RR)

            mmd = K_GG.mean() + K_RR.mean() - 2 * K_GR.mean()
            max_mmd = mmd if mmd > max_mmd else max_mmd

        return {self.name: max_mmd}

    @time_function
    def calculate_MMD_linear_kernel(self, generated_dataset=None, reference_dataset=None):
        # https://github.com/djsutherland/opt-mmd/blob/master/two_sample/mmd.py
        if not isinstance(generated_dataset, torch.Tensor) and not isinstance(generated_dataset, np.ndarray):
            (generated_dataset, reference_dataset), _ = self.get_activations(generated_dataset, reference_dataset)

        G_bar = generated_dataset.mean(axis=0)
        R_bar = reference_dataset.mean(axis=0)
        Z_bar = G_bar - R_bar
        mmd = Z_bar.dot(Z_bar)
        mmd = mmd if mmd >= 0 else 0
        return {'mmd_linear': mmd}


class prdcEvaluation(GINMetric):
    # From the PRDC github: https://github.com/clovaai/generative-evaluation-prdc/blob/master/prdc/prdc.py#L54
    def __init__(self, *args, use_pr=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_pr = use_pr

    @time_function
    def evaluate(self, generated_dataset=None, reference_dataset=None, nearest_k=5):
        """Compute precision, recall, density, and coverage given two manifolds."""
        if not isinstance(generated_dataset, torch.Tensor) and not isinstance(generated_dataset, np.ndarray):
            (generated_dataset, reference_dataset), _ = self.get_activations(generated_dataset, reference_dataset)

        real_nearest_neighbour_distances = self.__compute_nearest_neighbour_distances(reference_dataset, nearest_k)
        distance_real_fake = self.__compute_pairwise_distance(reference_dataset, generated_dataset)

        if self.use_pr:
            fake_nearest_neighbour_distances = self.__compute_nearest_neighbour_distances(generated_dataset, nearest_k)
            precision = (
                distance_real_fake <= np.expand_dims(real_nearest_neighbour_distances, axis=1)
            ).any(axis=0).mean()

            recall = (
                distance_real_fake <= np.expand_dims(fake_nearest_neighbour_distances, axis=0)
            ).any(axis=1).mean()

            f1_pr = 2 / ((1 / (precision + 1e-8)) + (1 / (recall + 1e-8)))
            result = dict(precision=precision, recall=recall, f1_pr=f1_pr)
        else:
            density = (1. / float(nearest_k)) * (
                distance_real_fake <= np.expand_dims(real_nearest_neighbour_distances, axis=1)).sum(axis=0).mean()

            coverage = (distance_real_fake.min(axis=1) <= real_nearest_neighbour_distances).mean()

            f1_dc = 2 / ((1 / (density + 1e-8)) + (1 / (coverage + 1e-8)))
            result = dict(density=density, coverage=coverage, f1_dc=f1_dc)
        return result

    def __compute_pairwise_distance(self, data_x, data_y=None):
        """
        Args:
            data_x: numpy.ndarray([N, feature_dim], dtype=np.float32)
            data_y: numpy.ndarray([N, feature_dim], dtype=np.float32)
        Returns:
            numpy.ndarray([N, N], dtype=np.float32) of pairwise distances.
        """
        if data_y is None:
            data_y = data_x
        dists = sklearn.metrics.pairwise_distances(data_x, data_y, metric='euclidean', n_jobs=8)
        return dists

    def __get_kth_value(self, unsorted, k, axis=-1):
        """
        Args:
            unsorted: numpy.ndarray of any dimensionality.
            k: int
        Returns:
            kth values along the designated axis.
        """
        indices = np.argpartition(unsorted, k, axis=axis)[..., :k]
        k_smallest = np.take_along_axis(unsorted, indices, axis=axis)
        kth_values = k_smallest.max(axis=axis)
        return kth_values

    def __compute_nearest_neighbour_distances(self, input_features, nearest_k):
        """
        Args:
            input_features: numpy.ndarray([N, feature_dim], dtype=np.float32)
            nearest_k: int
        Returns:
            Distances to kth nearest neighbours.
        """
        distances = self.__compute_pairwise_distance(input_features)
        radii = self.__get_kth_value(distances, k=nearest_k + 1, axis=-1)
        return radii


def nn_based_eval(graph_ref_list, graph_pred_list, N_gin=10):
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    evaluators = []
    for _ in range(N_gin):
        gin = load_feature_extractor(device)
        evaluators.append(MMDEvaluation(model=gin, kernel='rbf', sigma='range', multiplier='mean'))
        evaluators.append(prdcEvaluation(model=gin, use_pr=True))
        evaluators.append(prdcEvaluation(model=gin, use_pr=False))

    ref_graphs = [dgl.from_networkx(g).to(device) for g in graph_ref_list]
    gen_graphs = [dgl.from_networkx(g).to(device) for g in graph_pred_list]

    metrics = {
        'mmd_rbf': [],
        'f1_pr': [],
        'f1_dc': []
    }
    for evaluator in evaluators:
        res, _elapsed = evaluator.evaluate(generated_dataset=gen_graphs, reference_dataset=ref_graphs)
        for key in res:
            if key in metrics:
                metrics[key].append(res[key])

    results = {
        'MMD_RBF': (np.mean(metrics['mmd_rbf']), np.std(metrics['mmd_rbf'])),
        'F1_PR': (np.mean(metrics['f1_pr']), np.std(metrics['f1_pr'])),
        'F1_DC': (np.mean(metrics['f1_dc']), np.std(metrics['f1_dc']))
    }
    return results
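A sketch of the random-GIN evaluation entry point on toy graphs; the graph sizes are arbitrary assumptions, and a small N_gin keeps the smoke test fast:

# Hypothetical usage sketch: compare two sets of random graphs.
import networkx as nx

refs = [nx.erdos_renyi_graph(12, 0.3, seed=i) for i in range(16)]
gens = [nx.erdos_renyi_graph(12, 0.3, seed=100 + i) for i in range(16)]

results = nn_based_eval(refs, gens, N_gin=2)
print(results)  # {'MMD_RBF': (mean, std), 'F1_PR': (mean, std), 'F1_DC': (mean, std)}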
209
MobileNetV3/evaluation/structure_evaluator.py
Normal file
@@ -0,0 +1,209 @@
"""MMD Evaluation on graph structure statistics. Modified from https://github.com/uoguelph-mlrg/GGM-metrics"""

import concurrent.futures
from functools import partial

import numpy as np
import networkx as nx
from scipy.linalg import eigvalsh


class Descriptor():
    def __init__(self, is_parallel=False, bins=100, kernel='rbf', sigma_type='single', **kwargs):
        self.is_parallel = is_parallel
        self.bins = bins
        self.max_workers = kwargs.get('max_workers')

        if kernel == 'rbf':
            self.distance = self.l2
            self.name += '_rbf'
        else:
            raise ValueError

        if sigma_type == 'argmax':
            log_sigmas = np.linspace(-5., 5., 50)
            # The first 30 sigma values are usually enough.
            log_sigmas = log_sigmas[:30]
            self.sigmas = [np.exp(log_sigma) for log_sigma in log_sigmas]
        elif sigma_type == 'single':
            self.sigmas = kwargs['sigma']
        else:
            raise ValueError

    def evaluate(self, graph_ref_list, graph_pred_list):
        """Compute the distance between the distributions of two unordered sets of graphs.

        Args:
            graph_ref_list, graph_pred_list: two lists of networkx graphs to be evaluated.
        """
        graph_pred_list = [G for G in graph_pred_list if not G.number_of_nodes() == 0]

        sample_pred = self.extract_features(graph_pred_list)
        sample_ref = self.extract_features(graph_ref_list)

        GG = self.disc(sample_pred, sample_pred, distance_scaling=self.distance_scaling)
        GR = self.disc(sample_pred, sample_ref, distance_scaling=self.distance_scaling)
        RR = self.disc(sample_ref, sample_ref, distance_scaling=self.distance_scaling)

        # Report the maximum MMD over the sigma grid.
        max_mmd = 0
        for sigma in self.sigmas:
            gamma = 1 / (2 * sigma ** 2)

            K_GR = np.exp(-gamma * GR)
            K_GG = np.exp(-gamma * GG)
            K_RR = np.exp(-gamma * RR)

            mmd = K_GG.mean() + K_RR.mean() - (2 * K_GR.mean())
            max_mmd = mmd if mmd > max_mmd else max_mmd

        return max_mmd

    def pad_histogram(self, x, y):
        # Convert histogram values x and y to float and zero-pad them to equal length.
        support_size = max(len(x), len(y))
        x = x.astype(float)
        y = y.astype(float)
        if len(x) < len(y):
            x = np.hstack((x, [0.] * (support_size - len(x))))
        elif len(y) < len(x):
            y = np.hstack((y, [0.] * (support_size - len(y))))

        return x, y

    def l2(self, x, y, **kwargs):
        # Squared L2 distance between (padded) histograms, used with the Gaussian RBF kernel.
        x, y = self.pad_histogram(x, y)
        dist = np.linalg.norm(x - y, 2)
        return dist ** 2

    def kernel_parallel_unpacked(self, x, samples2, kernel):
        dist = []
        for s2 in samples2:
            dist += [kernel(x, s2)]
        return dist

    def kernel_parallel_worker(self, t):
        return self.kernel_parallel_unpacked(*t)

    def disc(self, samples1, samples2, **kwargs):
        # Discrepancy between two samples.
        tot_dist = []
        if not self.is_parallel:
            for s1 in samples1:
                for s2 in samples2:
                    tot_dist += [self.distance(s1, s2)]
        else:
            with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers) as executor:
                for dist in executor.map(self.kernel_parallel_worker,
                                         [(s1, samples2, partial(self.distance, **kwargs)) for s1 in samples1]):
                    tot_dist += [dist]
        return np.array(tot_dist)


class degree(Descriptor):
    def __init__(self, *args, **kwargs):
        self.name = 'degree'
        self.sigmas = [kwargs.get('sigma', 1.0)]
        self.distance_scaling = 1.0
        super().__init__(*args, **kwargs)

    def extract_features(self, dataset):
        res = []
        if self.is_parallel:
            with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers) as executor:
                for deg_hist in executor.map(self.degree_worker, dataset):
                    res.append(deg_hist)
        else:
            for g in dataset:
                degree_hist = self.degree_worker(g)
                res.append(degree_hist)

        res = [s1 / np.sum(s1) for s1 in res]
        return res

    def degree_worker(self, G):
        return np.array(nx.degree_histogram(G))


class cluster(Descriptor):
    def __init__(self, *args, **kwargs):
        self.name = 'cluster'
        self.sigmas = [kwargs.get('sigma', [1.0 / 10])]
        super().__init__(*args, **kwargs)
        self.distance_scaling = self.bins

    def extract_features(self, dataset):
        res = []
        if self.is_parallel:
            with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers) as executor:
                for clustering_hist in executor.map(self.clustering_worker, [(G, self.bins) for G in dataset]):
                    res.append(clustering_hist)
        else:
            for g in dataset:
                clustering_hist = self.clustering_worker((g, self.bins))
                res.append(clustering_hist)

        res = [s1 / np.sum(s1) for s1 in res]
        return res

    def clustering_worker(self, param):
        G, bins = param
        clustering_coeffs_list = list(nx.clustering(G).values())
        hist, _ = np.histogram(
            clustering_coeffs_list, bins=bins, range=(0.0, 1.0), density=False)
        return hist


class spectral(Descriptor):
    def __init__(self, *args, **kwargs):
        self.name = 'spectral'
        self.sigmas = [kwargs.get('sigma', 1.0)]
        self.distance_scaling = 1
        super().__init__(*args, **kwargs)

    def extract_features(self, dataset):
        res = []
        if self.is_parallel:
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                for spectral_density in executor.map(self.spectral_worker, dataset):
                    res.append(spectral_density)
        else:
            for g in dataset:
                spectral_temp = self.spectral_worker(g)
                res.append(spectral_temp)
        return res

    def spectral_worker(self, G):
        eigs = eigvalsh(nx.normalized_laplacian_matrix(G).todense())
        spectral_pmf, _ = np.histogram(eigs, bins=200, range=(-1e-5, 2), density=False)
        spectral_pmf = spectral_pmf / spectral_pmf.sum()
        return spectral_pmf


def mmd_eval(graph_ref_list, graph_pred_list, methods):
    evaluators = []
    for (method, sigma, sigma_type) in methods:
        evaluators.append(eval(method)(sigma=sigma, sigma_type=sigma_type))

    results = {}
    for evaluator in evaluators:
        results[evaluator.name] = evaluator.evaluate(graph_ref_list, graph_pred_list)

    return results
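A sketch of mmd_eval on toy graphs; the method triples mirror the ones built in evaluation/evaluator.py, and the graphs are arbitrary assumptions:

# Hypothetical usage sketch of the structure-statistics MMD.
import networkx as nx

methods = [('degree', 1., 'argmax'), ('cluster', 0.1, 'argmax'), ('spectral', 1., 'argmax')]
refs = [nx.erdos_renyi_graph(12, 0.3, seed=i) for i in range(16)]
gens = [nx.erdos_renyi_graph(12, 0.5, seed=100 + i) for i in range(16)]

print(mmd_eval(refs, gens, methods))
# -> {'degree_rbf': ..., 'cluster_rbf': ..., 'spectral_rbf': ...}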
180
MobileNetV3/logger.py
Normal file
@@ -0,0 +1,180 @@
import os

import numpy as np
import torch
import wandb


class Logger:
    def __init__(
            self,
            exp_name,
            log_dir=None,
            exp_suffix="",
            write_textfile=True,
            use_wandb=False,
            wandb_project_name=None,
            entity='hysh',
            config=None):
        self.log_dir = log_dir
        self.write_textfile = write_textfile
        self.use_wandb = use_wandb

        self.logs_for_save = {}
        self.logs = {}

        if self.write_textfile:
            self.f = open(os.path.join(log_dir, 'logs.txt'), 'w')

        if self.use_wandb:
            exp_suffix = "_".join(exp_suffix.split("/")[:-1])
            wandb.init(
                config=config if config is not None else wandb.config,
                entity=entity,
                project=wandb_project_name,
                name=exp_name + "_" + exp_suffix,
                group=exp_name,
                reinit=True)

    def write_str(self, log_str):
        self.f.write(log_str + '\n')
        self.f.flush()

    def update_config(self, v, is_args=False):
        if is_args:
            self.logs_for_save.update({'args': v})
        else:
            self.logs_for_save.update(v)
        if self.use_wandb:
            wandb.config.update(v, allow_val_change=True)

    def write_log_nohead(self, element, step):
        log_str = f"{step} | "
        log_dict = {}
        for key, val in element.items():
            if key not in self.logs_for_save:
                self.logs_for_save[key] = []
            self.logs_for_save[key].append(val)
            log_str += f'{key} {val} | '
            log_dict[f'{key}'] = val

        if self.write_textfile:
            self.f.write(log_str + '\n')
            self.f.flush()

        if self.use_wandb:
            wandb.log(log_dict, step=step)

    def write_log(self, element, step, return_log_dict=False):
        log_str = f"{step} | "
        log_dict = {}
        for head, keys in element.items():
            for k in keys:
                if k in self.logs:
                    v = self.logs[k].avg
                    if k not in self.logs_for_save:
                        self.logs_for_save[k] = []
                    self.logs_for_save[k].append(v)
                    log_str += f'{k} {v} | '
                    log_dict[f'{head}/{k}'] = v

        if self.write_textfile:
            self.f.write(log_str + '\n')
            self.f.flush()

        if return_log_dict:
            return log_dict

        if self.use_wandb:
            wandb.log(log_dict, step=step)

    def log_sample(self, sample_x):
        wandb.log({"sampled_x": [wandb.Image(x.unsqueeze(-1).cpu().numpy()) for x in sample_x]})

    def log_valid_sample_prop(self, arch_metric, x_axis, y_axis):
        assert x_axis in ['test_acc', 'flops', 'params', 'latency']
        assert y_axis in ['test_acc', 'flops', 'params', 'latency']

        data = [[x, y] for (x, y) in zip(arch_metric[2][f'{x_axis}_list'], arch_metric[2][f'{y_axis}_list'])]
        table = wandb.Table(data=data, columns=[x_axis, y_axis])
        wandb.log({f"valid_sample ({x_axis}-{y_axis})": wandb.plot.scatter(table, x_axis, y_axis)})

    def save_log(self, name=None):
        name = 'logs.pt' if name is None else name
        torch.save(self.logs_for_save, os.path.join(self.log_dir, name))

    def update(self, key, v, n=1):
        if key not in self.logs:
            self.logs[key] = AverageMeter()
        self.logs[key].update(v, n)

    def reset(self, keys=None, except_keys=[]):
        if keys is not None:
            if isinstance(keys, list):
                for key in keys:
                    self.logs[key] = AverageMeter()
            else:
                self.logs[keys] = AverageMeter()
        else:
            for key in self.logs.keys():
                if key not in except_keys:
                    self.logs[key] = AverageMeter()

    def avg(self, keys=None, except_keys=[]):
        if keys is not None:
            if isinstance(keys, list):
                return {key: self.logs[key].avg for key in keys if key in self.logs.keys()}
            else:
                return self.logs[keys].avg
        else:
            avg_dict = {}
            for key in self.logs.keys():
                if key not in except_keys:
                    avg_dict[key] = self.logs[key].avg
            return avg_dict


class AverageMeter(object):
    """Computes and stores the average and current value.

    Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def get_metrics(g_embeds, x_embeds, logit_scale, prefix='train'):
    metrics = {}
    logits_per_g = (logit_scale * g_embeds @ x_embeds.t()).detach().cpu()
    logits_per_x = logits_per_g.t().detach().cpu()

    logits = {"g_to_x": logits_per_g, "x_to_g": logits_per_x}
    ground_truth = torch.arange(len(x_embeds)).view(-1, 1)

    for name, logit in logits.items():
        ranking = torch.argsort(logit, descending=True)
        preds = torch.where(ranking == ground_truth)[1]
        preds = preds.detach().cpu().numpy()
        metrics[f"{prefix}_{name}_mean_rank"] = preds.mean() + 1
        metrics[f"{prefix}_{name}_median_rank"] = np.floor(np.median(preds)) + 1
        for k in [1, 5, 10]:
            metrics[f"{prefix}_{name}_R@{k}"] = np.mean(preds < k)

    return metrics
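A short sketch of the running-average bookkeeping the Logger relies on; the values are arbitrary:

# Hypothetical usage sketch of AverageMeter.
meter = AverageMeter()
meter.update(2.0, n=4)   # e.g. mean loss 2.0 over a batch of 4
meter.update(1.0, n=4)
print(meter.avg)         # -> 1.5, the sample-weighted running mean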
584
MobileNetV3/losses.py
Normal file
@@ -0,0 +1,584 @@
"""All functions related to loss computation and optimization."""

import torch
import torch.optim as optim
import numpy as np
from models import utils as mutils
from sde_lib import VPSDE, VESDE


def get_optimizer(config, params):
    """Return a PyTorch optimizer object based on `config`."""
    if config.optim.optimizer == 'Adam':
        optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps,
                               weight_decay=config.optim.weight_decay)
    else:
        raise NotImplementedError(
            f'Optimizer {config.optim.optimizer} not supported yet!')
    return optimizer


def optimization_manager(config):
    """Return an optimize_fn based on `config`."""

    def optimize_fn(optimizer, params, step, lr=config.optim.lr,
                    warmup=config.optim.warmup,
                    grad_clip=config.optim.grad_clip):
        """Optimize with warmup and gradient clipping (disabled if negative)."""
        if warmup > 0:
            for g in optimizer.param_groups:
                g['lr'] = lr * np.minimum(step / warmup, 1.0)
        if grad_clip >= 0:
            torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
        optimizer.step()

    return optimize_fn
|
||||
|
||||
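# --- Illustrative sketch (not part of the original file): wiring the two
# --- helpers above together. The config values below are toy assumptions,
# --- not the project's defaults.
from types import SimpleNamespace

_optim_cfg = SimpleNamespace(optimizer='Adam', lr=2e-4, beta1=0.9, eps=1e-8,
                             weight_decay=0.0, warmup=1000, grad_clip=1.0)
_toy_config = SimpleNamespace(optim=_optim_cfg)
_toy_model = torch.nn.Linear(4, 1)
_toy_optimizer = get_optimizer(_toy_config, _toy_model.parameters())
_toy_optimize_fn = optimization_manager(_toy_config)
_toy_loss = _toy_model(torch.randn(8, 4)).pow(2).mean()
_toy_loss.backward()
_toy_optimize_fn(_toy_optimizer, _toy_model.parameters(), step=1)  # lr warmed up to 2e-4 * (1/1000)
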
def get_sde_loss_fn_nas(sde, train, reduce_mean=True, continuous=True, likelihood_weighting=True, eps=1e-5):
    """Create a loss function for training with arbitrary SDEs.

    Args:
        sde: An `sde_lib.SDE` object that represents the forward SDE.
        train: `True` for training loss and `False` for evaluation loss.
        reduce_mean: If `True`, average the loss across data dimensions. Otherwise, sum the loss across data dimensions.
        continuous: `True` indicates that the model is defined to take continuous time steps.
            Otherwise, it requires ad-hoc interpolation to take continuous time steps.
        likelihood_weighting: If `True`, weight the mixture of score matching losses according
            to https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended in the Score SDE paper.
        eps: A `float` number. The smallest time step to sample from.

    Returns:
        A loss function.
    """

    def loss_fn(model, batch):
        """Compute the loss function.

        Args:
            model: A score model.
            batch: A mini-batch of training data, including adjacency matrices and mask.

        Returns:
            loss: A scalar that represents the average loss value across the mini-batch.
        """
        x, adj, mask = batch  # adj, mask: [B, 1, 20, 20]
        score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
        t = torch.rand(x.shape[0], device=adj.device) * (sde.T - eps) + eps

        z = torch.randn_like(x)  # same shape as the node features x
        mean, std = sde.marginal_prob(x, t)
        perturbed_data = mean + std[:, None, None] * z
        score = score_fn(perturbed_data, t, mask)

        if not likelihood_weighting:
            losses = torch.square(score * std[:, None, None] + z)
            losses = losses.reshape(losses.shape[0], -1)
            if reduce_mean:
                losses = torch.mean(losses, dim=-1)
            else:
                losses = 0.5 * torch.sum(losses, dim=-1)
            loss = losses.mean()
        else:
            g2 = sde.sde(torch.zeros_like(x), t)[1] ** 2
            losses = torch.square(score + z / std[:, None, None])
            losses = losses.reshape(losses.shape[0], -1)
            if reduce_mean:
                losses = torch.mean(losses, dim=-1)
            else:
                losses = 0.5 * torch.sum(losses, dim=-1)
            loss = (losses * g2).mean()

        return loss

    return loss_fn

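# --- Sanity-check sketch (not part of the original file): with the
# --- non-likelihood-weighted objective above, the optimal score is -z / std,
# --- which drives `score * std + z` (and hence the loss) to exactly zero.
# --- The tensor shape below is illustrative.
_z = torch.randn(4, 20, 9)                  # noise with the shape of the node features
_std = torch.rand(4) + 0.1                  # per-sample noise scales
_perfect_score = -_z / _std[:, None, None]
assert torch.allclose(_perfect_score * _std[:, None, None] + _z, torch.zeros_like(_z))
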
def get_predictor_loss_fn_nas_binary(sde, train, reduce_mean=True, continuous=True,
                                     likelihood_weighting=True, eps=1e-5, label_list=None,
                                     noised=True, t_spot=None):
    """Create a binary-classification loss for a predictor network.

    The predictor is trained on architectures that are optionally perturbed by the
    forward SDE (up to time `t_spot` when `t_spot < 1`), so that it remains accurate
    on noised inputs. `reduce_mean` and `likelihood_weighting` are accepted for
    interface compatibility but unused here.

    Args:
        sde: An `sde_lib.SDE` object that represents the forward SDE.
        train: `True` for training loss and `False` for evaluation loss.
        continuous: `True` indicates that the model is defined to take continuous time steps.
        eps: A `float` number. The smallest time step to sample from.
        label_list: Key(s) selecting the supervision signal from the batch's `extra` dict.
        noised: If `True`, train on SDE-perturbed inputs; otherwise on clean inputs at t = eps.
        t_spot: Upper bound on the sampled diffusion time when `t_spot < 1`.

    Returns:
        A loss function.
    """

    def loss_fn(model, batch):
        """Compute the BCE loss for one mini-batch.

        Returns:
            loss: A scalar averaged across the mini-batch.
            pred: Raw logits from the predictor.
            labels: Binary targets.
        """
        x, adj, mask, extra = batch  # adj, mask: [B, 1, 20, 20]
        predictor_fn = mutils.get_predictor_fn(sde, model, train=train, continuous=continuous)
        if noised:
            if t_spot < 1:
                t = torch.rand(x.shape[0], device=adj.device) * (t_spot - eps) + eps  # torch.rand: [0, 1)
            else:
                t = torch.rand(x.shape[0], device=adj.device) * (sde.T - eps) + eps

            z = torch.randn_like(x)
            mean, std = sde.marginal_prob(x, t)
            perturbed_data = mean + std[:, None, None] * z
            pred = predictor_fn(perturbed_data, t, mask)
        else:
            t = eps * torch.ones(x.shape[0], device=adj.device)
            pred = predictor_fn(x, t, mask)

        labels = extra[f"{label_list}"][1]
        labels = labels.to(pred.device).unsqueeze(1).type(pred.dtype)
        loss = torch.nn.BCEWithLogitsLoss()(pred, labels)

        return loss, pred, labels

    return loss_fn

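# --- Usage note (sketch, not part of the original file): BCEWithLogitsLoss
# --- expects raw logits and applies the sigmoid internally, which is why the
# --- predictor above returns un-squashed outputs.
_logits = torch.tensor([[2.0], [-1.0]])
_targets = torch.tensor([[1.0], [0.0]])
assert torch.nn.BCEWithLogitsLoss()(_logits, _targets) < torch.nn.BCEWithLogitsLoss()(-_logits, _targets)
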
def get_predictor_loss_fn_nas(sde, train, reduce_mean=True, continuous=True,
                              likelihood_weighting=True, eps=1e-5, label_list=None,
                              noised=True, t_spot=None):
    """Create a regression (MSE) loss for a predictor network.

    Identical to the binary variant above, except that the target is read from
    `extra[label_list[-1]]` and regressed with MSE. `reduce_mean` and
    `likelihood_weighting` are accepted for interface compatibility but unused here.

    Returns:
        A loss function.
    """

    def loss_fn(model, batch):
        """Compute the MSE loss for one mini-batch.

        Returns:
            loss: A scalar averaged across the mini-batch.
            pred: Predictor outputs.
            labels: Regression targets.
        """
        x, adj, mask, extra = batch  # adj, mask: [B, 1, 20, 20]
        predictor_fn = mutils.get_predictor_fn(sde, model, train=train, continuous=continuous)
        if noised:
            if t_spot < 1:
                t = torch.rand(x.shape[0], device=adj.device) * (t_spot - eps) + eps  # torch.rand: [0, 1)
            else:
                t = torch.rand(x.shape[0], device=adj.device) * (sde.T - eps) + eps

            z = torch.randn_like(x)
            mean, std = sde.marginal_prob(x, t)
            perturbed_data = mean + std[:, None, None] * z
            pred = predictor_fn(perturbed_data, t, mask)
        else:
            t = eps * torch.ones(x.shape[0], device=adj.device)
            pred = predictor_fn(x, t, mask)

        labels = extra[f"{label_list[-1]}"]
        labels = labels.to(pred.device).unsqueeze(1).type(pred.dtype)
        loss = torch.nn.MSELoss()(pred, labels)

        return loss, pred, labels

    return loss_fn

def get_meta_predictor_loss_fn_nas(sde, train, reduce_mean=True, continuous=True,
                                   likelihood_weighting=True, eps=1e-5, label_list=None,
                                   noised=True, t_spot=None):
    """Create a regression (MSE) loss for a task-conditioned meta-predictor.

    Same as `get_predictor_loss_fn_nas`, except that each batch carries a task
    embedding which is passed to the predictor alongside the (optionally noised)
    architecture.

    Returns:
        A loss function.
    """

    def loss_fn(model, batch):
        """Compute the MSE loss for one mini-batch.

        Returns:
            loss: A scalar averaged across the mini-batch.
            pred: Predictor outputs.
            labels: Regression targets.
        """
        x, adj, mask, extra, task = batch
        predictor_fn = mutils.get_predictor_fn(sde, model, train=train, continuous=continuous)
        if noised:
            if t_spot < 1:
                t = torch.rand(x.shape[0], device=adj.device) * (t_spot - eps) + eps  # torch.rand: [0, 1)
            else:
                t = torch.rand(x.shape[0], device=adj.device) * (sde.T - eps) + eps

            z = torch.randn_like(x)
            mean, std = sde.marginal_prob(x, t)
            perturbed_data = mean + std[:, None, None] * z
            pred = predictor_fn(perturbed_data, t, mask, task)
        else:
            t = eps * torch.ones(x.shape[0], device=adj.device)
            pred = predictor_fn(x, t, mask, task)

        labels = extra[f"{label_list[-1]}"]
        labels = labels.to(pred.device).unsqueeze(1).type(pred.dtype)
        loss = torch.nn.MSELoss()(pred, labels)

        return loss, pred, labels

    return loss_fn

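# --- Standalone sketch (not part of the original file) of the perturbation
# --- kernel used by the losses above, assuming the VP-SDE marginal from the
# --- score_sde reference implementation:
# ---   log_mean_coeff = -0.25 * t^2 * (beta_max - beta_min) - 0.5 * t * beta_min
# ---   mean = exp(log_mean_coeff) * x,   std = sqrt(1 - exp(2 * log_mean_coeff))
_beta_min, _beta_max = 0.1, 20.0
_x = torch.randn(4, 20, 9)
_t = torch.rand(4)
_log_mean_coeff = -0.25 * _t ** 2 * (_beta_max - _beta_min) - 0.5 * _t * _beta_min
_mean = torch.exp(_log_mean_coeff)[:, None, None] * _x
_std = torch.sqrt(1.0 - torch.exp(2.0 * _log_mean_coeff))
_perturbed = _mean + _std[:, None, None] * torch.randn_like(_x)
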
def get_sde_loss_fn(sde, train, reduce_mean=True, continuous=True, likelihood_weighting=True, eps=1e-5):
    """Create a loss function for training with arbitrary SDEs.

    Args:
        sde: An `sde_lib.SDE` object that represents the forward SDE.
        train: `True` for training loss and `False` for evaluation loss.
        reduce_mean: If `True`, average the loss across data dimensions. Otherwise, sum the loss across data dimensions.
        continuous: `True` indicates that the model is defined to take continuous time steps.
            Otherwise, it requires ad-hoc interpolation to take continuous time steps.
        likelihood_weighting: If `True`, weight the mixture of score matching losses according
            to https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended in the Score SDE paper.
        eps: A `float` number. The smallest time step to sample from.

    Returns:
        A loss function.
    """

    def loss_fn(model, batch):
        """Compute the loss function.

        Args:
            model: A score model.
            batch: A mini-batch of training data, including adjacency matrices and mask.

        Returns:
            loss: A scalar that represents the average loss value across the mini-batch.
        """
        adj, mask = batch  # adj, mask: [B, 1, 20, 20]
        score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
        t = torch.rand(adj.shape[0], device=adj.device) * (sde.T - eps) + eps

        z = torch.randn_like(adj)  # [B, C, N, N]
        z = torch.tril(z, -1)
        z = z + z.transpose(2, 3)

        mean, std = sde.marginal_prob(adj, t)
        mean = torch.tril(mean, -1)
        mean = mean + mean.transpose(2, 3)

        perturbed_data = mean + std[:, None, None, None] * z
        score = score_fn(perturbed_data, t, mask=mask)

        mask = torch.tril(mask, -1)
        mask = mask + mask.transpose(2, 3)
        mask = mask.reshape(mask.shape[0], -1)  # lower-triangular part of the adjacency matrices

        if not likelihood_weighting:
            losses = torch.square(score * std[:, None, None, None] + z)
            losses = losses.reshape(losses.shape[0], -1)
            if reduce_mean:
                losses = torch.sum(losses * mask, dim=-1) / torch.sum(mask, dim=-1)
            else:
                losses = 0.5 * torch.sum(losses * mask, dim=-1)
            loss = losses.mean()
        else:
            g2 = sde.sde(torch.zeros_like(adj), t)[1] ** 2
            losses = torch.square(score + z / std[:, None, None, None])
            losses = losses.reshape(losses.shape[0], -1)
            if reduce_mean:
                losses = torch.sum(losses * mask, dim=-1) / torch.sum(mask, dim=-1)
            else:
                losses = 0.5 * torch.sum(losses * mask, dim=-1)
            loss = (losses * g2).mean()

        return loss

    return loss_fn

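# --- Standalone sketch (not part of the original file) of the symmetrization
# --- used above: keeping the strict lower triangle and adding its transpose
# --- yields a symmetric matrix with zero diagonal, i.e. an undirected graph
# --- with no self-loops.
_zz = torch.randn(1, 1, 4, 4)
_zz = torch.tril(_zz, -1)
_zz = _zz + _zz.transpose(2, 3)
assert torch.equal(_zz, _zz.transpose(2, 3))
assert torch.all(_zz.diagonal(dim1=2, dim2=3) == 0)
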
def get_step_fn(sde, train, optimize_fn=None, reduce_mean=False, continuous=True,
                likelihood_weighting=False, data='NASBench201'):
    """Create a one-step training/evaluation function.

    Args:
        sde: An `sde_lib.SDE` object that represents the forward SDE, or a tuple
            (`sde_lib.SDE`, `sde_lib.SDE`) that represents the forward node SDE and edge SDE.
        optimize_fn: An optimization function.
        reduce_mean: If `True`, average the loss across data dimensions.
            Otherwise, sum the loss across data dimensions.
        continuous: `True` indicates that the model is defined to take continuous time steps.
        likelihood_weighting: If `True`, weight the mixture of score matching losses according to
            https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended by score-sde.

    Returns:
        A one-step function for training or evaluation.
    """

    if continuous:
        if isinstance(sde, tuple):
            loss_fn = get_multi_sde_loss_fn(sde[0], sde[1], train, reduce_mean=reduce_mean, continuous=True,
                                            likelihood_weighting=likelihood_weighting)
        else:
            if data in ['NASBench201', 'ofa']:
                loss_fn = get_sde_loss_fn_nas(sde, train, reduce_mean=reduce_mean,
                                              continuous=True, likelihood_weighting=likelihood_weighting)
            else:
                loss_fn = get_sde_loss_fn(sde, train, reduce_mean=reduce_mean,
                                          continuous=True, likelihood_weighting=likelihood_weighting)
    else:
        assert not likelihood_weighting, "Likelihood weighting is not supported for original SMLD/DDPM training."
        if isinstance(sde, VESDE):
            loss_fn = get_smld_loss_fn(sde, train, reduce_mean=reduce_mean)
        elif isinstance(sde, VPSDE):
            loss_fn = get_ddpm_loss_fn(sde, train, reduce_mean=reduce_mean)
        elif isinstance(sde, tuple):
            raise ValueError("Discrete training for multi sde is not recommended.")
        else:
            raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.")

    def step_fn(state, batch):
        """Run one step of training or evaluation.

        For the jax version: this function will undergo `jax.lax.scan` so that multiple steps can be
        pmapped and jit-compiled together for faster execution.

        Args:
            state: A dictionary of training information, containing the score model, optimizer,
                EMA status, and number of optimization steps.
            batch: A mini-batch of training/evaluation data, including adjacency matrices and mask.

        Returns:
            loss: The average loss value of this state.
        """
        model = state['model']
        if train:
            optimizer = state['optimizer']
            optimizer.zero_grad()
            loss = loss_fn(model, batch)
            loss.backward()
            optimize_fn(optimizer, model.parameters(), step=state['step'])
            state['step'] += 1
            state['ema'].update(model.parameters())
        else:
            with torch.no_grad():
                ema = state['ema']
                ema.store(model.parameters())
                ema.copy_to(model.parameters())
                loss = loss_fn(model, batch)
                ema.restore(model.parameters())

        return loss

    return step_fn

def get_step_fn_predictor(sde, train, optimize_fn=None, reduce_mean=False, continuous=True,
                          likelihood_weighting=False, data='NASBench201', label_list=None, noised=True,
                          t_spot=None, is_meta=False, is_binary=False):
    """Create a one-step training/evaluation function for predictor networks.

    Args:
        sde: An `sde_lib.SDE` object that represents the forward SDE, or a tuple
            (`sde_lib.SDE`, `sde_lib.SDE`) that represents the forward node SDE and edge SDE.
        optimize_fn: An optimization function.
        reduce_mean: If `True`, average the loss across data dimensions.
            Otherwise, sum the loss across data dimensions.
        continuous: `True` indicates that the model is defined to take continuous time steps.
        likelihood_weighting: If `True`, weight the mixture of score matching losses according to
            https://arxiv.org/abs/2101.09258; otherwise, use the weighting recommended by score-sde.
        is_meta: If `True`, use the task-conditioned meta-predictor loss.
        is_binary: If `True`, use the binary-classification predictor loss.

    Returns:
        A one-step function for training or evaluation.
    """

    if continuous:
        if isinstance(sde, tuple):
            loss_fn = get_multi_sde_loss_fn(sde[0], sde[1], train, reduce_mean=reduce_mean, continuous=True,
                                            likelihood_weighting=likelihood_weighting)
        else:
            if data in ['NASBench201', 'ofa']:
                if is_meta:
                    loss_fn = get_meta_predictor_loss_fn_nas(sde, train, reduce_mean=reduce_mean,
                                                             continuous=True, likelihood_weighting=likelihood_weighting,
                                                             label_list=label_list, noised=noised, t_spot=t_spot)
                elif is_binary:
                    loss_fn = get_predictor_loss_fn_nas_binary(sde, train, reduce_mean=reduce_mean,
                                                               continuous=True, likelihood_weighting=likelihood_weighting,
                                                               label_list=label_list, noised=noised, t_spot=t_spot)
                else:
                    loss_fn = get_predictor_loss_fn_nas(sde, train, reduce_mean=reduce_mean,
                                                        continuous=True, likelihood_weighting=likelihood_weighting,
                                                        label_list=label_list, noised=noised, t_spot=t_spot)
            else:
                loss_fn = get_sde_loss_fn(sde, train, reduce_mean=reduce_mean,
                                          continuous=True, likelihood_weighting=likelihood_weighting)
    else:
        assert not likelihood_weighting, "Likelihood weighting is not supported for original SMLD/DDPM training."
        if isinstance(sde, VESDE):
            loss_fn = get_smld_loss_fn(sde, train, reduce_mean=reduce_mean)
        elif isinstance(sde, VPSDE):
            loss_fn = get_ddpm_loss_fn(sde, train, reduce_mean=reduce_mean)
        elif isinstance(sde, tuple):
            raise ValueError("Discrete training for multi sde is not recommended.")
        else:
            raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.")

    def step_fn(state, batch):
        """Run one step of training or evaluation.

        Args:
            state: A dictionary of training information, containing the predictor model, optimizer,
                and number of optimization steps.
            batch: A mini-batch of training/evaluation data, including adjacency matrices and mask.

        Returns:
            loss, pred, labels for this step.
        """
        model = state['model']
        if train:
            model.train()
            optimizer = state['optimizer']
            optimizer.zero_grad()
            loss, pred, labels = loss_fn(model, batch)
            loss.backward()
            optimize_fn(optimizer, model.parameters(), step=state['step'])
            state['step'] += 1
        else:
            model.eval()
            with torch.no_grad():
                loss, pred, labels = loss_fn(model, batch)

        return loss, pred, labels

    return step_fn
40
MobileNetV3/main.py
Normal file
@@ -0,0 +1,40 @@
"""Training and evaluation"""
|
||||
|
||||
import run_lib
|
||||
from absl import app, flags
|
||||
from ml_collections.config_flags import config_flags
|
||||
import logging
|
||||
import os
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
config_flags.DEFINE_config_file(
|
||||
'config', None, 'Training configuration.', lock_config=True
|
||||
)
|
||||
config_flags.DEFINE_config_file(
|
||||
'classifier_config_nf', None, 'Training configuration.', lock_config=True
|
||||
)
|
||||
flags.DEFINE_string('workdir', None, 'Work directory.')
|
||||
flags.DEFINE_enum('mode', None, ['train', 'eval'],
|
||||
'Running mode: train or eval')
|
||||
flags.DEFINE_string('eval_folder', 'eval', 'The folder name for storing evaluation results')
|
||||
flags.mark_flags_as_required(['config', 'mode'])
|
||||
|
||||
|
||||
def main(argv):
|
||||
# Set random seed
|
||||
run_lib.set_random_seed(FLAGS.config)
|
||||
|
||||
if FLAGS.mode == 'train':
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel('INFO')
|
||||
# Run the training pipeline
|
||||
run_lib.train(FLAGS.config)
|
||||
elif FLAGS.mode == 'eval':
|
||||
run_lib.evaluate(FLAGS.config)
|
||||
else:
|
||||
raise ValueError(f"Mode {FLAGS.mode} not recognized.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
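
Since `--config` and `--mode` are the only required flags, a typical training invocation looks like the following (the config file path is illustrative; the repository's actual config module is not shown here):

    python MobileNetV3/main.py --config <path/to/config.py> --mode train --workdir ./exp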
329
MobileNetV3/main_exp/diffusion/run_lib.py
Normal file
@@ -0,0 +1,329 @@
import logging
import os
import shutil  # os/shutil/logging are used by the checkpoint helpers below
import sys

import numpy as np
import torch
from scipy.stats import pearsonr, spearmanr
from torch.utils.data import DataLoader

sys.path.append('.')
import sampling

import datasets_nas
from models import pgsn
from models import digcn
from models import cate
from models import dagformer
from models import digcn_meta
from models import regressor
from models.GDSS import scorenetx
from models import utils as mutils
from models.ema import ExponentialMovingAverage
import sde_lib
from utils import *
import losses

from analysis.arch_functions import BasicArchMetricsOFA
from analysis.arch_functions import NUM_STAGE, MAX_LAYER_PER_STAGE
from all_path import *

def get_sampling_fn(config, p=1, prod_w=False, weight_ratio_abs=False):
    # Set up the forward SDE
    if config.training.sde.lower() == 'vpsde':
        sde = sde_lib.VPSDE(
            beta_min=config.model.beta_min,
            beta_max=config.model.beta_max,
            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'subvpsde':
        sde = sde_lib.subVPSDE(
            beta_min=config.model.beta_min,
            beta_max=config.model.beta_max,
            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'vesde':
        sde = sde_lib.VESDE(
            sigma_min=config.model.sigma_min,
            sigma_max=config.model.sigma_max,
            N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    # Create the inverse of the data normalizer
    inverse_scaler = datasets_nas.get_data_inverse_scaler(config)

    sampling_shape = (
        config.eval.batch_size, config.data.max_node, config.data.n_vocab)  # ofa: 1024, 20, 28
    sampling_fn = sampling.get_sampling_fn(
        config, sde, sampling_shape, inverse_scaler,
        sampling_eps, config.data.name, conditional=True,
        p=p, prod_w=prod_w, weight_ratio_abs=weight_ratio_abs)

    return sampling_fn, sde

def get_sampling_fn_meta(config, p=1, prod_w=False, weight_ratio_abs=False, init=False, n_init=5):
    # Set up the forward SDE
    if config.training.sde.lower() == 'vpsde':
        sde = sde_lib.VPSDE(
            beta_min=config.model.beta_min,
            beta_max=config.model.beta_max,
            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'subvpsde':
        sde = sde_lib.subVPSDE(
            beta_min=config.model.beta_min,
            beta_max=config.model.beta_max,
            N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'vesde':
        sde = sde_lib.VESDE(
            sigma_min=config.model.sigma_min,
            sigma_max=config.model.sigma_max,
            N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    # Create the inverse of the data normalizer
    inverse_scaler = datasets_nas.get_data_inverse_scaler(config)

    if init:
        sampling_shape = (
            n_init, config.data.max_node, config.data.n_vocab)
    else:
        sampling_shape = (
            config.eval.batch_size, config.data.max_node, config.data.n_vocab)  # ofa: 1024, 20, 28
    sampling_fn = sampling.get_sampling_fn(
        config, sde, sampling_shape, inverse_scaler,
        sampling_eps, config.data.name, conditional=True,
        is_meta=True, data_name=config.sampling.check_dataname,
        num_sample=config.model.num_sample)

    return sampling_fn, sde

def get_score_model(config, pos_enc_type=2):
    # Build sampling functions and load the pre-trained score network
    score_config = torch.load(config.scorenet_ckpt_path)['config']
    ckpt_path = config.scorenet_ckpt_path
    score_config.sampling.corrector = 'langevin'
    score_config.model.pos_enc_type = pos_enc_type

    score_model = mutils.create_model(score_config)
    score_ema = ExponentialMovingAverage(
        score_model.parameters(), decay=score_config.model.ema_rate)
    score_state = dict(
        model=score_model, ema=score_ema, step=0, config=score_config)
    score_state = restore_checkpoint(
        ckpt_path, score_state,
        device=config.device, resume=True)
    score_ema.copy_to(score_model.parameters())
    return score_model, score_ema, score_config


def get_predictor(config):
    classifier_model = mutils.create_model(config)
    return classifier_model

def get_adj(data_name, except_inout):
    if data_name == 'NASBench201':
        _adj = np.asarray(
            [[0, 1, 1, 1, 0, 0, 0, 0],
             [0, 0, 0, 0, 1, 1, 0, 0],
             [0, 0, 0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 0]]
        )
        _adj = torch.tensor(_adj, dtype=torch.float32, device=torch.device('cpu'))
        if except_inout:
            _adj = _adj[1:-1, 1:-1]
    elif data_name == 'ofa':
        assert except_inout
        num_nodes = NUM_STAGE * MAX_LAYER_PER_STAGE
        _adj = torch.zeros(num_nodes, num_nodes)
        for i in range(num_nodes - 1):
            _adj[i, i + 1] = 1
    return _adj

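# --- Quick standalone check (sketch, not part of the original file) of the
# --- 'ofa' branch above: the architecture graph is a simple 20-node chain,
# --- i.e. ones on the first superdiagonal only.
_n = NUM_STAGE * MAX_LAYER_PER_STAGE  # 20
assert torch.equal(get_adj('ofa', except_inout=True),
                   torch.diag(torch.ones(_n - 1), diagonal=1))
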
def generate_archs(
        config, sampling_fn, score_model, score_ema, classifier_model,
        num_samples, patient_factor, batch_size=512, classifier_scale=None,
        task=None):

    metrics = BasicArchMetricsOFA()
    adj_s = get_adj(config.data.name, config.data.except_inout)
    mask_s = aug_mask(adj_s, algo=config.data.aug_mask_algo)[0]
    adj_c = get_adj(config.data.name, config.data.except_inout)
    mask_c = aug_mask(adj_c, algo=config.data.aug_mask_algo)[0]
    assert (adj_s == adj_c).all() and (mask_s == mask_c).all()
    adj_s, mask_s, adj_c, mask_c = \
        adj_s.to(config.device), mask_s.to(config.device), adj_c.to(config.device), mask_c.to(config.device)

    # Generate and save samples
    score_ema.copy_to(score_model.parameters())
    if num_samples > batch_size:
        num_sampling_rounds = int(np.ceil(num_samples / batch_size) * patient_factor)
    else:
        num_sampling_rounds = int(patient_factor)
    print(f'==> Sampling for {num_sampling_rounds} rounds...')

    r = 0
    all_samples = []
    classifier_scales = list(range(100000, 0, -int(classifier_scale)))

    while r < num_sampling_rounds:
        classifier_scale = classifier_scales[r]
        print(f'==> round {r} classifier_scale {classifier_scale}')
        sample, _, sample_chain, (score_grad_norm_p, classifier_grad_norm_p, score_grad_norm_c, classifier_grad_norm_c) \
            = sampling_fn(score_model, mask_s, classifier_model,
                          eval_chain=True,
                          number_chain_steps=config.sampling.number_chain_steps,
                          classifier_scale=classifier_scale,
                          task=task, sample_bs=num_samples)
        try:
            sample_list = quantize(sample, adj_s)  # quantization
            _, validity, valid_arch_str, _, _ = metrics.compute_validity(sample_list, adj_s, mask_s)
        except Exception:
            validity = 0.
            valid_arch_str = []
        print(f' ==> [Validity]: {round(validity, 4)}')

        if len(valid_arch_str) > 0:
            all_samples += valid_arch_str
        print(f' ==> [# Unique Arch]: {len(set(all_samples))}')

        if len(set(all_samples)) >= num_samples:
            break

        r += 1

    return list(set(all_samples))[:num_samples]

def noise_aware_meta_predictor_fit(config,
                                   predictor_model=None,
                                   xtrain=None,
                                   seed=None,
                                   sde=None,
                                   batch_size=5,
                                   epochs=50,
                                   save_best_p_corr=False,
                                   save_path=None):
    assert save_best_p_corr
    reset_seed(seed)

    data_loader = DataLoader(xtrain,
                             batch_size=batch_size,
                             shuffle=True,
                             drop_last=True)

    # Create the data normalizer
    scaler = datasets_nas.get_data_scaler(config)

    # Initialize the model state.
    optimizer = losses.get_optimizer(config, predictor_model.parameters())
    state = dict(optimizer=optimizer,
                 model=predictor_model,
                 step=0,
                 config=config)

    # Build the one-step training function
    optimize_fn = losses.optimization_manager(config)
    continuous = config.training.continuous
    reduce_mean = config.training.reduce_mean
    likelihood_weighting = config.training.likelihood_weighting
    train_step_fn = losses.get_step_fn_predictor(sde, train=True, optimize_fn=optimize_fn,
                                                 reduce_mean=reduce_mean, continuous=continuous,
                                                 likelihood_weighting=likelihood_weighting,
                                                 data=config.data.name, label_list=config.data.label_list,
                                                 noised=config.training.noised,
                                                 t_spot=config.training.t_spot,
                                                 is_meta=True)

    is_best = False
    best_p_corr = -1
    ckpt_dir = os.path.join(save_path, 'loop')
    print(f'==> Training for {epochs} epochs')
    for epoch in range(epochs):
        pred_list, labels_list = list(), list()
        for step, batch in enumerate(data_loader):
            x = batch['x'].to(config.device)  # e.g. (5, 5, 20, 9)
            adj = get_adj(config.data.name, config.data.except_inout)
            task = batch['task']
            extra = batch
            mask = aug_mask(adj,
                            algo=config.data.aug_mask_algo,
                            data=config.data.name)
            x = scaler(x.to(config.device))
            adj = adj.to(config.device)
            mask = mask.to(config.device)
            task = task.to(config.device)
            batch = (x, adj, mask, extra, task)
            # Execute one training step
            loss, pred, labels = train_step_fn(state, batch)
            pred_list += [v.detach().item() for v in pred.squeeze()]
            labels_list += [v.detach().item() for v in labels.squeeze()]
        p_corr = pearsonr(np.array(pred_list), np.array(labels_list))[0]
        s_corr = spearmanr(np.array(pred_list), np.array(labels_list))[0]
        if epoch % 50 == 0:
            print(f'==> [Epoch-{epoch}] P corr: {round(p_corr, 4)} | S corr: {round(s_corr, 4)}')

        if save_best_p_corr and p_corr > best_p_corr:
            is_best = True
            best_p_corr = p_corr
            os.makedirs(ckpt_dir, exist_ok=True)
            save_checkpoint(ckpt_dir, state, epoch, is_best)
    if save_best_p_corr:
        loaded_state = torch.load(os.path.join(ckpt_dir, 'model_best.pth.tar'), map_location=config.device)
        predictor_model.load_state_dict(loaded_state['model'])

def save_checkpoint(ckpt_dir, state, epoch, is_best):
    saved_state = {}
    for k in state:
        if k in ['optimizer', 'model', 'ema']:
            saved_state.update({k: state[k].state_dict()})
        else:
            saved_state.update({k: state[k]})
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(saved_state, os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth.tar'))
    if is_best:
        shutil.copy(os.path.join(ckpt_dir, f'checkpoint_{epoch}.pth.tar'),
                    os.path.join(ckpt_dir, 'model_best.pth.tar'))
    # Remove every per-epoch checkpoint, keeping only the best state
    for ckpt_file in sorted(os.listdir(ckpt_dir)):
        if not ckpt_file.startswith('checkpoint'):
            continue
        if os.path.join(ckpt_dir, ckpt_file) != os.path.join(ckpt_dir, 'model_best.pth.tar'):
            os.remove(os.path.join(ckpt_dir, ckpt_file))

def restore_checkpoint(ckpt_dir, state, device, resume=False):
    if not resume:
        os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
        return state
    elif not os.path.exists(ckpt_dir):
        if not os.path.exists(os.path.dirname(ckpt_dir)):
            os.makedirs(os.path.dirname(ckpt_dir))
        logging.warning(f"No checkpoint found at {ckpt_dir}. "
                        f"Returned the same state as input.")
        return state
    else:
        loaded_state = torch.load(ckpt_dir, map_location=device)
        for k in state:
            if k in ['optimizer', 'model', 'ema']:
                state[k].load_state_dict(loaded_state[k])
            else:
                state[k] = loaded_state[k]
        return state
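
The two checkpoint helpers above follow the usual PyTorch convention of serializing a dict of `state_dict`s. A minimal, self-contained round-trip sketch (the model and paths are illustrative, not repository code):

import os
import tempfile

import torch

model = torch.nn.Linear(2, 1)
ckpt_path = os.path.join(tempfile.mkdtemp(), 'checkpoint_7.pth.tar')
torch.save({'model': model.state_dict(), 'step': 7}, ckpt_path)

restored = torch.nn.Linear(2, 1)
loaded = torch.load(ckpt_path, map_location='cpu')
restored.load_state_dict(loaded['model'])
assert loaded['step'] == 7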
63
MobileNetV3/main_exp/get_files/get_aircraft.py
Normal file
@@ -0,0 +1,63 @@
"""
|
||||
@author: Hayeon Lee
|
||||
2020/02/19
|
||||
Script for downloading, and reorganizing aircraft
|
||||
for few shot classification
|
||||
Run this file as follows:
|
||||
python get_data.py
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import os
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import requests
|
||||
import tarfile
|
||||
from PIL import Image
|
||||
import glob
|
||||
import shutil
|
||||
import pickle
|
||||
import collections
|
||||
import sys
|
||||
sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
|
||||
from all_path import RAW_DATA_PATH
|
||||
|
||||
def download_file(url, filename):
|
||||
"""
|
||||
Helper method handling downloading large files from `url`
|
||||
to `filename`. Returns a pointer to `filename`.
|
||||
"""
|
||||
chunkSize = 1024
|
||||
r = requests.get(url, stream=True)
|
||||
with open(filename, 'wb') as f:
|
||||
pbar = tqdm( unit="B", total=int( r.headers['Content-Length'] ) )
|
||||
for chunk in r.iter_content(chunk_size=chunkSize):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
pbar.update (len(chunk))
|
||||
f.write(chunk)
|
||||
return filename
|
||||
|
||||
dir_path = RAW_DATA_PATH
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
file_name = os.path.join(dir_path, 'fgvc-aircraft-2013b.tar.gz')
|
||||
|
||||
if not os.path.exists(file_name):
|
||||
print(f"Downloading {file_name}\n")
|
||||
download_file(
|
||||
'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz',
|
||||
file_name)
|
||||
print("\nDownloading done.\n")
|
||||
else:
|
||||
print("fgvc-aircraft-2013b.tar.gz has already been downloaded. Did not download twice.\n")
|
||||
|
||||
untar_file_name = os.path.join(dir_path, 'aircraft')
|
||||
if not os.path.exists(untar_file_name):
|
||||
tarname = file_name
|
||||
print("Untarring: {}".format(tarname))
|
||||
tar = tarfile.open(tarname)
|
||||
tar.extractall(untar_file_name)
|
||||
tar.close()
|
||||
else:
|
||||
print(f"{untar_file_name} folder already exists. Did not untarring twice\n")
|
||||
os.remove(file_name)
|
||||
50
MobileNetV3/main_exp/get_files/get_pets.py
Normal file
@@ -0,0 +1,50 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
import sys
import zipfile

import requests
from tqdm import tqdm

sys.path.append(os.path.join(os.getcwd(), 'main_exp'))
from all_path import RAW_DATA_PATH


def download_file(url, filename):
    """
    Helper method handling downloading large files from `url`
    to `filename`. Returns a pointer to `filename`.
    """
    chunk_size = 1024
    r = requests.get(url, stream=True)
    with open(filename, 'wb') as f:
        pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
        for chunk in r.iter_content(chunk_size=chunk_size):
            if chunk:  # filter out keep-alive new chunks
                pbar.update(len(chunk))
                f.write(chunk)
    return filename


dir_path = os.path.join(RAW_DATA_PATH, 'pets')
if not os.path.exists(dir_path):
    os.makedirs(dir_path)

full_name = os.path.join(dir_path, 'test15.pth')
if not os.path.exists(full_name):
    print(f"Downloading {full_name}\n")
    download_file(
        'https://www.dropbox.com/s/kzmrwyyk5iaugv0/test15.pth?dl=1', full_name)
    print("Downloading done.\n")
else:
    print(f"{full_name} has already been downloaded. Did not download twice.\n")

full_name = os.path.join(dir_path, 'train85.pth')
if not os.path.exists(full_name):
    print(f"Downloading {full_name}\n")
    download_file(
        'https://www.dropbox.com/s/w7mikpztkamnw9s/train85.pth?dl=1', full_name)
    print("Downloading done.\n")
else:
    print(f"{full_name} has already been downloaded. Did not download twice.\n")
46
MobileNetV3/main_exp/get_files/get_preprocessed_data.py
Normal file
@@ -0,0 +1,46 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os

import requests
from tqdm import tqdm

from all_path import PROCESSED_DATA_PATH

dir_path = PROCESSED_DATA_PATH
if not os.path.exists(dir_path):
    os.makedirs(dir_path)


def download_file(url, filename):
    """
    Helper method handling downloading large files from `url`
    to `filename`. Returns a pointer to `filename`.
    """
    chunk_size = 1024
    r = requests.get(url, stream=True)
    with open(filename, 'wb') as f:
        pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
        for chunk in r.iter_content(chunk_size=chunk_size):
            if chunk:  # filter out keep-alive new chunks
                pbar.update(len(chunk))
                f.write(chunk)
    return filename


def get_preprocessed_data(file_name, url):
    print(f"Downloading {file_name} datasets\n")
    full_name = os.path.join(dir_path, file_name)
    download_file(url, full_name)
    print("Downloading done.\n")


for file_name, url in [
    ('aircraftbylabel.pt', 'https://www.dropbox.com/s/nn6mlrk1jijg108/aircraft100bylabel.pt?dl=1'),
    # NOTE: this entry points at the same aircraft URL as the line above.
    ('cifar100bylabel.pt', 'https://www.dropbox.com/s/nn6mlrk1jijg108/aircraft100bylabel.pt?dl=1'),
    ('cifar10bylabel.pt', 'https://www.dropbox.com/s/wt1pcwi991xyhwr/cifar10bylabel.pt?dl=1'),
    ('imgnet32bylabel.pt', 'https://www.dropbox.com/s/7r3hpugql8qgi9d/imgnet32bylabel.pt?dl=1'),
    ('petsbylabel.pt', 'https://www.dropbox.com/s/mxh6qz3grhy7wcn/petsbylabel.pt?dl=1'),
]:
    get_preprocessed_data(file_name, url)
@@ -0,0 +1,44 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os

import requests
from tqdm import tqdm

DATA_PATH = "./data/ofa/data_score_model"
dir_path = DATA_PATH
if not os.path.exists(dir_path):
    os.makedirs(dir_path)


def download_file(url, filename):
    """
    Helper method handling downloading large files from `url`
    to `filename`. Returns a pointer to `filename`.
    """
    chunk_size = 1024
    r = requests.get(url, stream=True)
    with open(filename, 'wb') as f:
        pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
        for chunk in r.iter_content(chunk_size=chunk_size):
            if chunk:  # filter out keep-alive new chunks
                pbar.update(len(chunk))
                f.write(chunk)
    return filename


def get_preprocessed_data(file_name, url):
    print(f"Downloading {file_name} datasets\n")
    full_name = os.path.join(dir_path, file_name)
    download_file(url, full_name)
    print("Downloading done.\n")


for file_name, url in [
    ('ofa_database_500000.pt', 'https://www.dropbox.com/scl/fi/0asz5qnvakf6ggucuynkk/ofa_database_500000.pt?rlkey=lqa1y4d6mikgzznevtanl2ybx&dl=1'),
    ('ridx-500000.pt', 'https://www.dropbox.com/scl/fi/ambrm9n5efdkyydmsli0h/ridx-500000.pt?rlkey=b6iliyuiaxya4ropms8chsa7c&dl=1'),
]:
    get_preprocessed_data(file_name, url)
390
MobileNetV3/main_exp/nag.py
Normal file
@@ -0,0 +1,390 @@
from __future__ import print_function
import os
import gc
import sys
import time

import numpy as np
import torch
from tqdm import tqdm
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from scipy.stats import pearsonr

from transfer_nag_lib.MetaD2A_mobilenetV3.metad2a_utils import load_graph_config, decode_ofa_mbv3_str_to_igraph
from transfer_nag_lib.MetaD2A_mobilenetV3.metad2a_utils import get_log
from transfer_nag_lib.MetaD2A_mobilenetV3.metad2a_utils import save_model, mean_confidence_interval
from transfer_nag_lib.MetaD2A_mobilenetV3.loader import get_meta_train_loader, MetaTestDataset
from transfer_nag_lib.encoder_FSBO_ofa import EncoderFSBO as PredictorModel
from transfer_nag_lib.MetaD2A_mobilenetV3.predictor import Predictor as MetaD2APredictor
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.train import train_single_model

from diffusion.run_lib import generate_archs
from diffusion.run_lib import get_sampling_fn_meta
from diffusion.run_lib import get_score_model
from diffusion.run_lib import get_predictor

sys.path.append(os.path.join(os.getcwd()))
from all_path import *
from utils import restore_checkpoint


class NAG:
    def __init__(self, args, dgp_arch=[99, 50, 179, 194], bohb=False):
        self.args = args
        self.batch_size = args.batch_size
        self.num_sample = args.num_sample
        self.max_epoch = args.max_epoch
        self.save_epoch = args.save_epoch
        self.save_path = args.save_path
        self.search_space = args.search_space
        self.model_name = 'predictor'
        self.test = args.test
        self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
        self.max_corr_dict = {'corr': -1, 'epoch': -1}
        self.train_arch = args.train_arch
        self.use_metad2a_predictor_selec = args.use_metad2a_predictor_selec

        self.raw_data_path = RAW_DATA_PATH
        self.model_path = UNNOISE_META_PREDICTOR_CKPT_PATH
        self.data_path = PROCESSED_DATA_PATH
        self.classifier_ckpt_path = NOISE_META_PREDICTOR_CKPT_PATH
        self.load_diffusion_model(self.args.n_training_samples, args.pos_enc_type)

        graph_config = load_graph_config(
            args.graph_data_name, args.nvt, self.data_path)

        self.model = PredictorModel(args, graph_config, dgp_arch=dgp_arch)
        self.metad2a_model = MetaD2APredictor(args).model

        if self.test:
            self.data_name = args.data_name
            self.num_class = args.num_class
            self.load_epoch = args.load_epoch
            self.n_training_samples = self.args.n_training_samples
            self.n_gen_samples = args.n_gen_samples
            self.folder_name = args.folder_name
            self.unique = args.unique

            model_state_dict = self.model.state_dict()
            load_max_pt = 'ckpt_max_corr.pt'
            ckpt_path = os.path.join(self.model_path, load_max_pt)
            ckpt = torch.load(ckpt_path)
            for k, v in ckpt.items():
                if k in model_state_dict.keys():
                    model_state_dict[k] = v
            self.model.cpu()
            self.model.load_state_dict(model_state_dict)
            self.model.to(self.device)

        self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr)
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min',
                                           factor=0.1, patience=1000, verbose=True)
        self.mtrloader = get_meta_train_loader(
            self.batch_size, self.data_path, self.num_sample, is_pred=True)

        self.acc_mean = self.mtrloader.dataset.mean
        self.acc_std = self.mtrloader.dataset.std

    def forward(self, x, arch, labels=None, train=False, matrix=False, metad2a=False):
        if metad2a:
            D_mu = self.metad2a_model.set_encode(x.to(self.device))
            G_mu = self.metad2a_model.graph_encode(arch)
            y_pred = self.metad2a_model.predict(D_mu, G_mu)
            return y_pred
        else:
            D_mu = self.model.set_encode(x.to(self.device))
            G_mu = self.model.graph_encode(arch, matrix=matrix)
            y_pred, y_dist = self.model.predict(D_mu, G_mu, labels=labels, train=train)
            return y_pred, y_dist

    def meta_train(self):
        sttime = time.time()
        for epoch in range(1, self.max_epoch + 1):
            self.mtrlog.ep_sttime = time.time()
            loss, corr = self.meta_train_epoch(epoch)
            self.scheduler.step(loss)
            self.mtrlog.print_pred_log(loss, corr, 'train', epoch)
            valoss, vacorr = self.meta_validation(epoch)
            if self.max_corr_dict['corr'] < vacorr or epoch == 1:
                self.max_corr_dict['corr'] = vacorr
                self.max_corr_dict['epoch'] = epoch
                self.max_corr_dict['loss'] = valoss
                save_model(epoch, self.model, self.model_path, max_corr=True)

            self.mtrlog.print_pred_log(
                valoss, vacorr, 'valid', max_corr_dict=self.max_corr_dict)

            if epoch % self.save_epoch == 0:
                save_model(epoch, self.model, self.model_path)

        self.mtrlog.save_time_log()
        self.mtrlog.max_corr_log(self.max_corr_dict)

    def meta_train_epoch(self, epoch):
        self.model.to(self.device)
        self.model.train()

        self.mtrloader.dataset.set_mode('train')

        dlen = len(self.mtrloader.dataset)
        trloss = 0
        y_all, y_pred_all = [], []
        pbar = tqdm(self.mtrloader)

        for x, g, acc in pbar:
            self.optimizer.zero_grad()
            y_pred, y_dist = self.forward(x, g, labels=acc, train=True, matrix=False)
            y = acc.to(self.device).double()
            loss = -self.model.mll(y_dist, y)
            loss.backward()
            self.optimizer.step()

            y = y.tolist()
            y_pred = y_pred.squeeze().tolist()
            y_all += y
            y_pred_all += y_pred
            pbar.set_description(get_log(
                epoch, loss, y_pred, y, self.acc_std, self.acc_mean))
            trloss += float(loss)

        return trloss / dlen, pearsonr(np.array(y_all),
                                       np.array(y_pred_all))[0]

    def meta_validation(self, epoch):
        self.model.to(self.device)
        self.model.eval()

        valoss = 0
        self.mtrloader.dataset.set_mode('valid')
        dlen = len(self.mtrloader.dataset)
        y_all, y_pred_all = [], []
        pbar = tqdm(self.mtrloader)

        with torch.no_grad():
            for x, g, acc in pbar:
                y_pred, y_dist = self.forward(x, g, labels=acc, train=False, matrix=False)
                y = acc.to(self.device)
                loss = -self.model.mll(y_dist, y)

                y = y.tolist()
                y_pred = y_pred.squeeze().tolist()
                y_all += y
                y_pred_all += y_pred
                pbar.set_description(get_log(
                    epoch, loss, y_pred, y, self.acc_std, self.acc_mean, tag='val'))
                valoss += float(loss)
        try:
            pearson_corr = pearsonr(np.array(y_all), np.array(y_pred_all))[0]
        except Exception:
            pearson_corr = 0

        return valoss / dlen, pearson_corr

    def meta_test(self):
        if self.data_name == 'all':
            for data_name in ['cifar10', 'cifar100', 'aircraft', 'pets']:
                acc = self.meta_test_per_dataset(data_name)
        else:
            acc = self.meta_test_per_dataset(self.data_name)
        return acc

    def meta_test_per_dataset(self, data_name):
        self.test_dataset = MetaTestDataset(
            self.data_path, data_name, self.num_sample, self.num_class)

        meta_test_path = self.args.exp_name
        os.makedirs(meta_test_path, exist_ok=True)
        f_arch_str = open(os.path.join(meta_test_path, 'architecture.txt'), 'w')
        f = open(os.path.join(meta_test_path, 'accuracy.txt'), 'w')

        elapsed_time = []

        print(f'==> select top architectures for {data_name} by meta-predictor...')

        gen_arch_str = self.get_gen_arch_str()

        gen_arch_igraph = [decode_ofa_mbv3_str_to_igraph(_) for _ in gen_arch_str]

        y_pred_all = []
        self.metad2a_model.eval()
        self.metad2a_model.to(self.device)

        # MetaD2A-style prediction
        sttime = time.time()
        with torch.no_grad():
            for i, arch_igraph in enumerate(gen_arch_igraph):
                x, g = self.collect_data(arch_igraph)
                y_pred = self.forward(x, g, metad2a=True)
                y_pred = torch.mean(y_pred)
                y_pred_all.append(y_pred.cpu().detach().item())

        if self.use_metad2a_predictor_selec:
            top_arch_lst = self.select_top_arch(
                data_name, torch.tensor(y_pred_all), gen_arch_str, self.n_training_samples)
        else:
            top_arch_lst = gen_arch_str[:self.n_training_samples]

        elapsed = time.time() - sttime
        elapsed_time.append(elapsed)

        for _, arch_str in enumerate(top_arch_lst):
            f_arch_str.write(f'{arch_str}\n')
            print(f'neural architecture config: {arch_str}')

        support = top_arch_lst
        x_support = []
        y_support = []
        seeds = [777, 888, 999]
        y_support_per_seed = {
            _: [] for _ in seeds
        }
        net_info = {
            'params': [],
            'flops': [],
        }
        best_acc = 0.0
        best_sample_num = 0

        print("Data name: %s" % data_name)
        for i, arch_str in enumerate(support):
            save_path = os.path.join(meta_test_path, arch_str)
            os.makedirs(save_path, exist_ok=True)
            acc_runs = []
            for seed in seeds:
                print(f'==> train for {data_name} {arch_str} ({seed})')
                valid_acc, max_valid_acc, params, flops = train_single_model(save_path=save_path,
                                                                             workers=8,
                                                                             datasets=data_name,
                                                                             xpaths=f'{self.raw_data_path}/{data_name}',
                                                                             splits=[0],
                                                                             use_less=False,
                                                                             seed=seed,
                                                                             model_str=arch_str,
                                                                             device='cuda',
                                                                             lr=0.01,
                                                                             momentum=0.9,
                                                                             weight_decay=4e-5,
                                                                             report_freq=50,
                                                                             epochs=20,
                                                                             grad_clip=5,
                                                                             cutout=True,
                                                                             cutout_length=16,
                                                                             autoaugment=True,
                                                                             drop=0.2,
                                                                             drop_path=0.2,
                                                                             img_size=224)
                acc_runs.append(valid_acc)
                y_support_per_seed[seed].append(valid_acc)

            for r, acc in enumerate(acc_runs):
                msg = f'run {r + 1} {acc:.2f} (%)'
                f.write(msg + '\n')
                f.flush()
                print(msg)
            m, h = mean_confidence_interval(acc_runs)

            if m > best_acc:
                best_acc = m
                best_sample_num = i
            msg = f'Avg {m:.3f}+-{h.item():.2f} (%) (best acc {best_acc:.3f} - #{i})'
            f.write(msg + '\n')
            print(msg)
            y_support.append(np.mean(acc_runs))
            x_support.append(arch_str)
            net_info['params'].append(params)
            net_info['flops'].append(flops)
            torch.save({'y_support': y_support, 'x_support': x_support,
                        'y_support_per_seed': y_support_per_seed,
                        'net_info': net_info,
                        'best_acc': best_acc,
                        'best_sample_num': best_sample_num},
                       meta_test_path + '/result.pt')

        return None

def train_single_arch(self, data_name, arch_str, meta_test_path):
|
||||
save_path = os.path.join(meta_test_path, arch_str)
|
||||
seeds = (777, 888, 999)
|
||||
train_single_model(save_path=save_path,
|
||||
workers=24,
|
||||
datasets=[data_name],
|
||||
xpaths=[f'{self.raw_data_path}/{data_name}'],
|
||||
splits=[0],
|
||||
use_less=False,
|
||||
seeds=seeds,
|
||||
model_str=arch_str,
|
||||
arch_config={'channel': 16, 'num_cells': 5})
|
||||
# Changed training time from 49/199
|
||||
epoch = 49 if data_name == 'mnist' else 199
|
||||
test_acc_lst = []
|
||||
for seed in seeds:
|
||||
result = torch.load(os.path.join(save_path, f'seed-0{seed}.pth'))
|
||||
test_acc_lst.append(result[data_name]['valid_acc1es'][f'x-test@{epoch}'])
|
||||
return test_acc_lst
|
||||
|
||||
|
||||
def select_top_arch(
|
||||
self, data_name, y_pred_all, gen_arch_str, N):
|
||||
_, sorted_idx = torch.sort(y_pred_all, descending=True)
|
||||
sotred_gen_arch_str = [gen_arch_str[_] for _ in sorted_idx]
|
||||
final_str = sotred_gen_arch_str[:N]
|
||||
return final_str
|
||||
|
||||
def collect_data_only(self):
|
||||
x_batch = []
|
||||
x_batch.append(self.test_dataset[0])
|
||||
return torch.stack(x_batch).to(self.device)
|
||||
|
||||
def collect_data(self, arch_igraph):
|
||||
x_batch, g_batch = [], []
|
||||
for _ in range(10):
|
||||
x_batch.append(self.test_dataset[0])
|
||||
g_batch.append(arch_igraph)
|
||||
return torch.stack(x_batch).to(self.device), g_batch
|
||||
|
||||
def load_diffusion_model(self, n_training_samples, pos_enc_type):
|
||||
self.config = torch.load(CONFIG_PATH)
|
||||
self.config.data.root = SCORE_MODEL_DATA_PATH
|
||||
self.config.scorenet_ckpt_path = SCORE_MODEL_CKPT_PATH
|
||||
torch.save(self.config, CONFIG_PATH)
|
||||
|
||||
self.sampling_fn, self.sde = get_sampling_fn_meta(self.config)
|
||||
self.sampling_fn_training_samples, _ = get_sampling_fn_meta(self.config, init=True, n_init=n_training_samples)
|
||||
self.score_model, self.score_ema, self.score_config \
|
||||
= get_score_model(self.config, pos_enc_type=pos_enc_type)
|
||||
|
||||
def get_gen_arch_str(self):
|
||||
classifier_config = torch.load(self.classifier_ckpt_path)['config']
|
||||
# Load meta-predictor
|
||||
classifier_model = get_predictor(classifier_config)
|
||||
classifier_state = dict(model=classifier_model, step=0, config=classifier_config)
|
||||
classifier_state = restore_checkpoint(self.classifier_ckpt_path,
|
||||
classifier_state, device=self.config.device, resume=True)
|
||||
print(f'==> load checkpoint for our predictor: {self.classifier_ckpt_path}...')
|
||||
|
||||
with torch.no_grad():
|
||||
x = self.collect_data_only()
|
||||
|
||||
generated_arch_str = generate_archs(
|
||||
self.config,
|
||||
self.sampling_fn,
|
||||
self.score_model,
|
||||
self.score_ema,
|
||||
classifier_model,
|
||||
num_samples=self.n_gen_samples,
|
||||
patient_factor=self.args.patient_factor,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
classifier_scale=self.args.classifier_scale,
|
||||
task=x if self.args.fix_task else None)
|
||||
|
||||
gc.collect()
|
||||
return generated_arch_str
|
||||
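The per-architecture results above are reduced to `mean ± half-width` via `mean_confidence_interval`, whose implementation is not shown in this commit. A minimal stand-in sketch, assuming a standard t-based interval over the three seed runs (the function name and return convention here are assumptions; the repository's own helper apparently returns a tensor, since the caller invokes `h.item()`, while this sketch returns plain floats):

```python
import numpy as np
from scipy import stats

def mean_confidence_interval_sketch(data, confidence=0.95):
    # Mean and half-width of a t-based confidence interval over seed runs.
    a = np.asarray(data, dtype=float)
    m = a.mean()
    h = stats.sem(a) * stats.t.ppf((1 + confidence) / 2.0, len(a) - 1)
    return m, h

m, h = mean_confidence_interval_sketch([58.2, 59.1, 57.8])
print(f'Avg {m:.3f}+-{h:.2f} (%)')
```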
154
MobileNetV3/main_exp/run_transfer_nag.py
Normal file
@@ -0,0 +1,154 @@
import os
import sys
import random
import numpy as np
import argparse
import torch
from nag import NAG
# sys.path.append(os.getcwd())
# from utils import str2bool


def str2bool(v):
    return v.lower() in ['t', 'true']


# save_path = "results"
# data_path = os.path.join('MetaD2A_nas_bench_201', 'data')
# model_load_path = '/home/data/GTAD/baselines/transferNAS'


def get_parser():
    parser = argparse.ArgumentParser()
    # general settings
    parser.add_argument('--seed', type=int, default=444)
    parser.add_argument('--gpu', type=str, default='0',
                        help='set visible gpus')
    parser.add_argument('--search_space', type=str, default='ofa')
    parser.add_argument('--save-path', type=str,
                        default=None, help='the path of save directory')
    parser.add_argument('--data-path', type=str,
                        default=None, help='the path of data directory')
    parser.add_argument('--model-load-path', type=str,
                        default=None, help='')
    parser.add_argument('--save-epoch', type=int, default=20,
                        help='how many epochs to wait each time to save model states')
    parser.add_argument('--max-epoch', type=int, default=50,
                        help='number of epochs to train')
    parser.add_argument('--batch_size', type=int,
                        default=1024, help='batch size for generator')
    parser.add_argument('--graph-data-name',
                        default='ofa', help='graph dataset name')
    parser.add_argument('--nvt', type=int, default=27,
                        help='number of different node types')
    # set encoder
    parser.add_argument('--num-sample', type=int, default=20,
                        help='the number of images as input for set encoder')
    # graph encoder
    parser.add_argument('--hs', type=int, default=512,
                        help='hidden size of GRUs')
    parser.add_argument('--nz', type=int, default=56,
                        help='the number of dimensions of latent vectors z')
    # test
    parser.add_argument('--test', action='store_true',
                        default=True, help='turn on test mode')
    parser.add_argument('--load-epoch', type=int, default=100,
                        help='checkpoint epoch loaded for meta-test')
    parser.add_argument('--data-name', type=str,
                        default='pets', help='meta-test dataset name')
    parser.add_argument('--trials', type=int, default=5)

    parser.add_argument('--num-class', type=int, default=None,
                        help='the number of classes of the dataset')
    parser.add_argument('--num-gen-arch', type=int, default=500,
                        help='the number of candidate architectures generated by the generator')
    parser.add_argument('--train-arch', type=str2bool, default=True,
                        help='whether to train the searched architecture')
    parser.add_argument('--n_training_samples', type=int, default=5)
    parser.add_argument('--N', type=int, default=10)
    parser.add_argument('--use_gp', type=str2bool, default=False)
    parser.add_argument('--sorting', type=str2bool, default=True)
    parser.add_argument('--use_metad2a_predictor_selec', type=str2bool, default=True)
    parser.add_argument('--use_ensemble_selec', type=str2bool, default=False)

    # ---------- For diffusion NAG ------------ #
    parser.add_argument('--folder_name', type=str, default='DiffusionNAG')
    parser.add_argument('--task', type=str, default='mtst')
    parser.add_argument('--exp_name', type=str, default='')
    parser.add_argument('--wandb_exp_name', type=str, default='')
    parser.add_argument('--wandb_project_name', type=str, default='DiffusionNAG')
    parser.add_argument('--use_wandb', type=str2bool, default=False)
    parser.add_argument('--classifier_scale', type=float, default=10000.0, help='classifier scale')
    parser.add_argument('--eval_batch_size', type=int, default=256)
    parser.add_argument('--predictor', type=str, default='euler_maruyama',
                        choices=['euler_maruyama', 'reverse_diffusion', 'none'])
    parser.add_argument('--corrector', type=str, default='langevin',
                        choices=['none', 'langevin'])
    parser.add_argument('--weight_ratio', type=str2bool, default=False)
    parser.add_argument('--weight_scheduling', type=str2bool, default=False)
    parser.add_argument('--weight_ratio_abs', type=str2bool, default=False)
    parser.add_argument('--p', type=int, default=1)
    parser.add_argument('--prod_w', type=str2bool, default=False)
    parser.add_argument('--t_spot', type=float, default=1.0)
    parser.add_argument('--t_spot_end', type=float, default=0.0)
    # Train
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--save_best_p_corr', type=str2bool, default=True)
    parser.add_argument('--unique', type=str2bool, default=True)
    parser.add_argument('--patient_factor', type=int, default=20)
    parser.add_argument('--n_gen_samples', type=int, default=50)
    ################ OFA ####################
    parser.add_argument('--ofa_path', type=str, default='/home/hayeon/imagenet1k', help='')
    parser.add_argument('--ofa_batch_size', type=int, default=256, help='')
    parser.add_argument('--ofa_workers', type=int, default=4, help='')
    ################ Diffusion ##############
    parser.add_argument('--diffusion_lr', type=float, default=1e-3, help='')
    parser.add_argument('--noise_aware_acc_norm', type=int, default=-1)
    parser.add_argument('--fix_task', type=str2bool, default=True)
    ################ BO ####################
    parser.add_argument('--bo_loop_max_epoch', type=int, default=30)
    parser.add_argument('--bo_loop_acc_norm', type=int, default=1)
    parser.add_argument('--gp_model_acc_norm', type=int, default=1)
    parser.add_argument('--num_ensemble', type=int, default=3)
    parser.add_argument('--explore_type', type=str, default='ei')
    ################ BO ####################
    # parser.add_argument('--multi_proc', type=str2bool, default=False)
    parser.add_argument('--eps', type=float, default=0.)
    parser.add_argument('--beta', type=float, default=0.5)
    parser.add_argument('--pos_enc_type', type=int, default=4)
    args = parser.parse_args()

    return args


def set_exp_name(args):
    exp_name = f'./exp/{args.task}/{args.folder_name}/data-{args.data_name}'
    wandb_exp_name = f'./exp/{args.task}/{args.folder_name}/{args.data_name}'

    os.makedirs(exp_name, exist_ok=True)
    args.exp_name = exp_name
    args.wandb_exp_name = wandb_exp_name


def main():
    args = get_parser()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.cuda.manual_seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    set_exp_name(args)

    p = NAG(args)

    if args.test:
        p.meta_test()
    else:
        p.meta_train()


if __name__ == '__main__':
    main()
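For reference, the script above is driven entirely by its argparse flags; a typical meta-test invocation under the defaults might look like the following (the flag values are illustrative, not prescribed by the commit):

```shell script
$ python MobileNetV3/main_exp/run_transfer_nag.py \
    --gpu 0 --data-name cifar10 \
    --n_gen_samples 50 --classifier_scale 10000.0
```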
100
MobileNetV3/main_exp/transfer_nag_lib/DeepKernelGPHelpers.py
Normal file
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 6 14:02:53 2021

@author: hsjomaa
"""
import numpy as np
from scipy.stats import norm
import pandas as pd
from torch import autograd as ag
import torch
from sklearn.preprocessing import PowerTransformer


def regret(output, response):
    # Running incumbent (best-so-far) curve over the evaluation trajectory.
    incumbent = output[0]
    best_output = []
    for _ in output:
        incumbent = _ if _ > incumbent else incumbent
        best_output.append(incumbent)
    opt = max(response)
    orde = list(np.sort(np.unique(response))[::-1])
    tmp = pd.DataFrame(best_output, columns=['regret_validation'])

    tmp['rank_valid'] = tmp['regret_validation'].map(lambda x: orde.index(x))
    tmp['regret_validation'] = opt - tmp['regret_validation']
    return tmp


def EI(incumbent, model_fn, support, queries, return_variance, return_score=False):
    mu, stddev = model_fn(queries)
    mu = mu.reshape(-1,)
    stddev = stddev.reshape(-1,)
    if return_variance:
        # the model returned variances; convert to standard deviations
        stddev = np.sqrt(stddev)
    with np.errstate(divide='warn'):
        imp = mu - incumbent
        Z = imp / stddev
        score = imp * norm.cdf(Z) + stddev * norm.pdf(Z)
    if not return_score:
        # mask out already-evaluated points, then pick the best candidate
        score[support] = 0
        return np.argmax(score)
    else:
        return score


class Metric(object):
    def __init__(self, prefix='train: '):
        self.reset()
        self.message = prefix + "loss: {loss:.2f} - noise: {log_var:.2f} - mse: {mse:.2f}"

    def update(self, loss, noise, mse):
        self.loss.append(np.asscalar(loss))
        self.noise.append(np.asscalar(noise))
        self.mse.append(np.asscalar(mse))

    def reset(self,):
        self.loss = []
        self.noise = []
        self.mse = []

    def report(self):
        return self.message.format(loss=np.mean(self.loss),
                                   log_var=np.mean(self.noise),
                                   mse=np.mean(self.mse))

    def get(self):
        return {"loss": np.mean(self.loss),
                "noise": np.mean(self.noise),
                "mse": np.mean(self.mse)}


def totorch(x, device):
    if type(x) is tuple:
        return tuple([ag.Variable(torch.Tensor(e)).to(device) for e in x])
    return torch.Tensor(x).to(device)


def prepare_data(indexes, support, Lambda, response, metafeatures=None, output_transform=False):
    # Generate indexes of the batch
    X, E, Z, y, r = [], [], [], [], []
    #### get support data
    for dim in indexes:
        if metafeatures is not None:
            Z.append(metafeatures)
        E.append(Lambda[support])
        X.append(Lambda[dim])
        r_ = response[support, np.newaxis]
        y_ = response[dim]
        if output_transform:
            power = PowerTransformer(method="yeo-johnson")
            r_ = power.fit_transform(r_)
            y_ = power.transform(y_.reshape(-1, 1)).reshape(-1,)
        r.append(r_)
        y.append(y_)
    X = np.array(X)
    E = np.array(E)
    Z = np.array(Z)
    y = np.array(y)
    r = np.array(r)
    return (np.expand_dims(E, axis=-1), r, np.expand_dims(X, axis=-1), Z), y
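`EI` above is the standard Expected Improvement acquisition, score = (mu - incumbent) * Phi(Z) + sigma * phi(Z) with Z = (mu - incumbent) / sigma, where already-evaluated support points are zeroed out before taking the argmax. A minimal usage sketch with a hypothetical surrogate (`toy_model_fn` is an assumption for illustration, not part of the repository):

```python
import numpy as np

def toy_model_fn(queries):
    # Hypothetical surrogate: mean rises along the query list, constant stddev.
    mu = np.linspace(0.0, 1.0, len(queries))
    stddev = np.full(len(queries), 0.1)
    return mu, stddev

queries = list(range(8))
support = [0, 1]                 # already-evaluated indices are masked to 0
next_idx = EI(incumbent=0.6, model_fn=toy_model_fn, support=support,
              queries=queries, return_variance=False)
print(next_idx)                  # index with the highest expected improvement
```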
581
MobileNetV3/main_exp/transfer_nag_lib/DeepKernelGPModules.py
Normal file
@@ -0,0 +1,581 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 6 14:03:42 2021

@author: hsjomaa
"""
## Original packages
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler
import copy
import numpy as np
import os
from torch.utils.tensorboard import SummaryWriter
import json
import time
## Our packages
import gpytorch
import logging
from transfer_nag_lib.DeepKernelGPHelpers import totorch, prepare_data, Metric, EI
from transfer_nag_lib.MetaD2A_nas_bench_201.generator import Generator
from transfer_nag_lib.MetaD2A_nas_bench_201.main import get_parser
np.random.seed(1203)
RandomQueryGenerator = np.random.RandomState(413)
RandomSupportGenerator = np.random.RandomState(413)
RandomTaskGenerator = np.random.RandomState(413)


class DeepKernelGP(nn.Module):

    def __init__(self, X, Y, Z, kernel, backbone_fn, config, support, log_dir, seed):
        super(DeepKernelGP, self).__init__()
        torch.manual_seed(seed)
        ## GP parameters
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.X, self.Y, self.Z = X, Y, Z
        self.feature_extractor = backbone_fn().to(self.device)
        self.config = config
        self.get_model_likelihood_mll(len(support), kernel, backbone_fn)

        logging.basicConfig(filename=log_dir, level=logging.DEBUG)

    def get_model_likelihood_mll(self, train_size, kernel, backbone_fn):

        train_x = torch.ones(train_size, self.feature_extractor.out_features).to(self.device)
        train_y = torch.ones(train_size).to(self.device)

        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        model = ExactGPLayer(train_x=train_x, train_y=train_y, likelihood=likelihood, config=self.config,
                             dims=self.feature_extractor.out_features)
        self.model = model.to(self.device)
        self.likelihood = likelihood.to(self.device)
        self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model).to(self.device)

    def set_forward(self, x, is_feature=False):
        pass

    def set_forward_loss(self, x):
        pass

    def train(self, support, load_model, optimizer, checkpoint=None, epochs=1000, verbose=False):

        if load_model:
            assert (checkpoint is not None)
            print("KEYS MATCHED")
            self.load_checkpoint(os.path.join(checkpoint, "weights"))

        inputs, labels = prepare_data(support, support, self.X, self.Y, self.Z)
        inputs, labels = totorch(inputs, device=self.device), totorch(labels.reshape(-1,), device=self.device)
        losses = [np.inf]
        best_loss = np.inf
        starttime = time.time()
        initial_weights = copy.deepcopy(self.state_dict())
        patience = 0
        max_patience = self.config["patience"]
        for _ in range(epochs):
            optimizer.zero_grad()
            z = self.feature_extractor(inputs)
            self.model.set_train_data(inputs=z, targets=labels)
            predictions = self.model(z)
            try:
                loss = -self.mll(predictions, self.model.train_targets)
                loss.backward()
                optimizer.step()
            except Exception as ada:
                logging.info(f"Exception {ada}")
                break

            if verbose:
                print("Iter {iter}/{epochs} - Loss: {loss:.5f} noise: {noise:.5f}".format(
                    iter=_ + 1, epochs=epochs, loss=loss.item(), noise=self.likelihood.noise.item()))
            losses.append(loss.detach().to("cpu").item())
            if best_loss > losses[-1]:
                best_loss = losses[-1]
                weights = copy.deepcopy(self.state_dict())
            if np.allclose(losses[-1], losses[-2], atol=self.config["loss_tol"]):
                patience += 1
            else:
                patience = 0
            if patience > max_patience:
                break
        self.load_state_dict(weights)
        logging.info(f"Current Iteration: {len(support)} | Incumbent {max(self.Y[support])} | Duration {np.round(time.time()-starttime)} | Epochs {_} | Noise {self.likelihood.noise.item()}")
        return losses, weights, initial_weights

    def load_checkpoint(self, checkpoint):
        ckpt = torch.load(checkpoint, map_location=torch.device(self.device))
        self.model.load_state_dict(ckpt['gp'], strict=False)
        self.likelihood.load_state_dict(ckpt['likelihood'], strict=False)
        self.feature_extractor.load_state_dict(ckpt['net'], strict=False)

    def predict(self, support, query_range=None, noise_fn=None):

        card = len(self.Y)
        if noise_fn:
            self.Y = noise_fn(self.Y)
        x_support, y_support = prepare_data(support, support,
                                            self.X, self.Y, self.Z)
        if query_range is None:
            x_query, _ = prepare_data(np.arange(card), support,
                                      self.X, self.Y, self.Z)
        else:
            x_query, _ = prepare_data(query_range, support,
                                      self.X, self.Y, self.Z)
        self.model.eval()
        self.feature_extractor.eval()
        self.likelihood.eval()

        z_support = self.feature_extractor(totorch(x_support, self.device)).detach()
        self.model.set_train_data(inputs=z_support, targets=totorch(y_support.reshape(-1,), self.device), strict=False)

        with torch.no_grad():
            z_query = self.feature_extractor(totorch(x_query, self.device)).detach()
            pred = self.likelihood(self.model(z_query))

        mu = pred.mean.detach().to("cpu").numpy().reshape(-1,)
        stddev = pred.stddev.detach().to("cpu").numpy().reshape(-1,)

        return mu, stddev


class DKT(nn.Module):
    def __init__(self, train_data, valid_data, kernel, backbone_fn, config):
        super(DKT, self).__init__()
        ## GP parameters
        self.train_data = train_data
        self.valid_data = valid_data
        self.fixed_context_size = config["fixed_context_size"]
        self.minibatch_size = config["minibatch_size"]
        self.n_inner_steps = config["n_inner_steps"]
        self.checkpoint_path = config["checkpoint_path"]
        os.makedirs(self.checkpoint_path, exist_ok=False)
        json.dump(config, open(os.path.join(self.checkpoint_path, "configuration.json"), "w"))
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.basicConfig(filename=os.path.join(self.checkpoint_path, "log.txt"), level=logging.DEBUG)
        self.feature_extractor = backbone_fn().to(self.device)
        self.config = config
        self.get_model_likelihood_mll(self.fixed_context_size, kernel, backbone_fn)
        self.mse = nn.MSELoss()
        self.curr_valid_loss = np.inf
        self.get_tasks()
        self.setup_writers()

        self.train_metrics = Metric()
        self.valid_metrics = Metric(prefix="valid: ")
        print(self)

    def setup_writers(self,):
        train_log_dir = os.path.join(self.checkpoint_path, "train")
        os.makedirs(train_log_dir, exist_ok=True)
        self.train_summary_writer = SummaryWriter(train_log_dir)

        valid_log_dir = os.path.join(self.checkpoint_path, "valid")
        os.makedirs(valid_log_dir, exist_ok=True)
        self.valid_summary_writer = SummaryWriter(valid_log_dir)

    def get_tasks(self,):
        pairs = []
        for space in self.train_data.keys():
            for task in self.train_data[space].keys():
                pairs.append([space, task])
        self.tasks = pairs
        ##########
        pairs = []
        for space in self.valid_data.keys():
            for task in self.valid_data[space].keys():
                pairs.append([space, task])
        self.valid_tasks = pairs

    def get_model_likelihood_mll(self, train_size, kernel, backbone_fn):

        train_x = torch.ones(train_size, self.feature_extractor.out_features).to(self.device)
        train_y = torch.ones(train_size).to(self.device)

        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        model = ExactGPLayer(train_x=train_x, train_y=train_y, likelihood=likelihood, config=self.config, dims=self.feature_extractor.out_features)
        self.model = model.to(self.device)
        self.likelihood = likelihood.to(self.device)
        self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model).to(self.device)

    def set_forward(self, x, is_feature=False):
        pass

    def set_forward_loss(self, x):
        pass

    def epoch_end(self):
        RandomTaskGenerator.shuffle(self.tasks)

    def train_loop(self, epoch, optimizer, scheduler_fn=None):
        if scheduler_fn:
            scheduler = scheduler_fn(optimizer, len(self.tasks))
        self.epoch_end()
        assert (self.training)
        for task in self.tasks:
            inputs, labels = self.get_batch(task)
            for _ in range(self.n_inner_steps):
                optimizer.zero_grad()
                z = self.feature_extractor(inputs)
                self.model.set_train_data(inputs=z, targets=labels, strict=False)
                predictions = self.model(z)
                loss = -self.mll(predictions, self.model.train_targets)
                loss.backward()
                optimizer.step()
                mse = self.mse(predictions.mean, labels)
                self.train_metrics.update(loss, self.model.likelihood.noise, mse)
            if scheduler_fn:
                scheduler.step()

        training_results = self.train_metrics.get()
        for k, v in training_results.items():
            self.train_summary_writer.add_scalar(k, v, epoch)
        for task in self.valid_tasks:
            mse, loss = self.test_loop(task, train=False)
            self.valid_metrics.update(loss, np.array(0), mse)

        logging.info(self.train_metrics.report() + " " + self.valid_metrics.report())
        validation_results = self.valid_metrics.get()
        for k, v in validation_results.items():
            self.valid_summary_writer.add_scalar(k, v, epoch)
        self.feature_extractor.train()
        self.likelihood.train()
        self.model.train()

        if validation_results["loss"] < self.curr_valid_loss:
            self.save_checkpoint(os.path.join(self.checkpoint_path, "weights"))
            self.curr_valid_loss = validation_results["loss"]
        self.valid_metrics.reset()
        self.train_metrics.reset()

    def test_loop(self, task, train, optimizer=None):  # no optimizer needed for GP
        (x_support, y_support), (x_query, y_query) = self.get_support_and_queries(task, train)
        z_support = self.feature_extractor(x_support).detach()
        self.model.set_train_data(inputs=z_support, targets=y_support, strict=False)
        self.model.eval()
        self.feature_extractor.eval()
        self.likelihood.eval()

        with torch.no_grad():
            z_query = self.feature_extractor(x_query).detach()
            pred = self.likelihood(self.model(z_query))
            loss = -self.mll(pred, y_query)
            lower, upper = pred.confidence_region()  # 2 standard deviations above and below the mean

        mse = self.mse(pred.mean, y_query)

        return mse, loss

    def get_batch(self, task):
        # we want to fit the gp given context info to new observations
        # task is an algorithm/dataset pair
        space, task = task
        Lambda, response = np.array(self.train_data[space][task]["X"]), MinMaxScaler().fit_transform(np.array(self.train_data[space][task]["y"])).reshape(-1,)

        card, dim = Lambda.shape

        support = RandomSupportGenerator.choice(np.arange(card),
                                                replace=False, size=self.fixed_context_size)
        remaining = np.setdiff1d(np.arange(card), support)
        indexes = RandomQueryGenerator.choice(
            remaining, replace=False, size=self.minibatch_size if len(remaining) > self.minibatch_size else len(remaining))

        inputs, labels = prepare_data(support, indexes, Lambda, response, np.zeros(32))
        inputs, labels = totorch(inputs, device=self.device), totorch(labels.reshape(-1,), device=self.device)
        return inputs, labels

    def get_support_and_queries(self, task, train=False):

        # task is an algorithm/dataset pair
        space, task = task

        hpo_data = self.valid_data if not train else self.train_data
        Lambda, response = np.array(hpo_data[space][task]["X"]), MinMaxScaler().fit_transform(np.array(hpo_data[space][task]["y"])).reshape(-1,)
        card, dim = Lambda.shape

        support = RandomSupportGenerator.choice(np.arange(card),
                                                replace=False, size=self.fixed_context_size)
        indexes = RandomQueryGenerator.choice(
            np.setdiff1d(np.arange(card), support), replace=False, size=self.minibatch_size)

        support_x, support_y = prepare_data(support, support, Lambda, response, np.zeros(32))
        query_x, query_y = prepare_data(support, indexes, Lambda, response, np.zeros(32))

        return (totorch(support_x, self.device), totorch(support_y.reshape(-1,), self.device)), \
               (totorch(query_x, self.device), totorch(query_y.reshape(-1,), self.device))

    def save_checkpoint(self, checkpoint):
        # save state
        gp_state_dict = self.model.state_dict()
        likelihood_state_dict = self.likelihood.state_dict()
        nn_state_dict = self.feature_extractor.state_dict()
        torch.save({'gp': gp_state_dict, 'likelihood': likelihood_state_dict, 'net': nn_state_dict}, checkpoint)

    def load_checkpoint(self, checkpoint):
        ckpt = torch.load(checkpoint)
        self.model.load_state_dict(ckpt['gp'])
        self.likelihood.load_state_dict(ckpt['likelihood'])
        self.feature_extractor.load_state_dict(ckpt['net'])


class ExactGPLayer(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, config, dims):
        super(ExactGPLayer, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()

        ## RBF kernel
        if (config["kernel"] == 'rbf' or config["kernel"] == 'RBF'):
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=dims if config["ard"] else None))
        ## Matern 5/2 kernel
        elif (config["kernel"] == '52'):
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel(nu=config["nu"], ard_num_dims=dims if config["ard"] else None))
        else:
            raise ValueError("[ERROR] the kernel '" + str(config["kernel"]) + "' is not supported for regression, use 'rbf' or '52'.")

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


class batch_mlp(nn.Module):
    def __init__(self, d_in, output_sizes, nonlinearity="relu", dropout=0.0):

        super(batch_mlp, self).__init__()
        assert (nonlinearity == "relu")
        self.nonlinearity = nn.ReLU()

        self.fc = nn.ModuleList([nn.Linear(in_features=d_in, out_features=output_sizes[0])])
        for d_out in output_sizes[1:]:
            self.fc.append(nn.Linear(in_features=self.fc[-1].out_features, out_features=d_out))
        self.out_features = output_sizes[-1]
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):

        for fc in self.fc[:-1]:
            x = fc(x)
            x = self.dropout(x)
            x = self.nonlinearity(x)
        x = self.fc[-1](x)
        x = self.dropout(x)
        return x


class StandardDeepGP(nn.Module):
    def __init__(self, configuration):

        super(StandardDeepGP, self).__init__()
        self.A = batch_mlp(configuration["dim"], configuration["output_size_A"], dropout=configuration["dropout"])
        self.out_features = configuration["output_size_A"][-1]

    def forward(self, x):
        # e,r,x,z = x
        hidden = self.A(x.squeeze(dim=-1))  ### NxA
        return hidden


class DKTNAS(nn.Module):
    def __init__(self, kernel, backbone_fn, config, pretrained_encoder=True, GP_only=False):
        super(DKTNAS, self).__init__()
        ## GP parameters

        self.fixed_context_size = config["fixed_context_size"]
        self.minibatch_size = config["minibatch_size"]
        self.n_inner_steps = config["n_inner_steps"]
        self.set_encoder_args = get_parser()
        if not os.path.exists(self.set_encoder_args.save_path):
            os.makedirs(self.set_encoder_args.save_path)
        self.set_encoder_args.model_path = os.path.join(self.set_encoder_args.save_path,
                                                        self.set_encoder_args.model_name, 'model')
        if not os.path.exists(self.set_encoder_args.model_path):
            os.makedirs(self.set_encoder_args.model_path)
        self.set_encoder = Generator(self.set_encoder_args)
        if pretrained_encoder:
            self.dataset_enc, self.arch, self.acc = self.set_encoder.train_dgp(encode=False)
            self.dataset_enc_val, self.acc_val = self.set_encoder.test_dgp(data_name='cifar100', encode=False)
        else:  # In case we want to train the set-encoder from scratch
            self.dataset_enc = np.load("train_data_path.npy")
            self.acc = np.load("train_acc.npy")
            self.dataset_enc_val = np.load("cifar100_data_path.npy")
            self.acc_val = np.load("cifar100_acc.npy")
        self.valid_data = self.dataset_enc_val
        self.checkpoint_path = config["checkpoint_path"]
        os.makedirs(self.checkpoint_path, exist_ok=False)
        json.dump(config, open(os.path.join(self.checkpoint_path, "configuration.json"), "w"))
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.basicConfig(filename=os.path.join(self.checkpoint_path, "log.txt"), level=logging.DEBUG)
        self.feature_extractor = backbone_fn().to(self.device)
        self.config = config
        self.GP_only = GP_only
        self.get_model_likelihood_mll(self.fixed_context_size, kernel, backbone_fn)
        self.mse = nn.MSELoss()
        self.curr_valid_loss = np.inf
        # self.get_tasks()
        self.setup_writers()

        self.train_metrics = Metric()
        self.valid_metrics = Metric(prefix="valid: ")
        self.tasks = len(self.dataset_enc)

        print(self)

    def setup_writers(self, ):
        train_log_dir = os.path.join(self.checkpoint_path, "train")
        os.makedirs(train_log_dir, exist_ok=True)
        self.train_summary_writer = SummaryWriter(train_log_dir)

        valid_log_dir = os.path.join(self.checkpoint_path, "valid")
        os.makedirs(valid_log_dir, exist_ok=True)
        self.valid_summary_writer = SummaryWriter(valid_log_dir)

    def get_model_likelihood_mll(self, train_size, kernel, backbone_fn):
        if not self.GP_only:
            train_x = torch.ones(train_size, self.feature_extractor.out_features).to(self.device)
            train_y = torch.ones(train_size).to(self.device)

            likelihood = gpytorch.likelihoods.GaussianLikelihood()

            model = ExactGPLayer(train_x=None, train_y=None, likelihood=likelihood, config=self.config,
                                 dims=self.feature_extractor.out_features)
        else:
            train_x = torch.ones(train_size, self.fixed_context_size).to(self.device)
            train_y = torch.ones(train_size).to(self.device)

            likelihood = gpytorch.likelihoods.GaussianLikelihood()

            model = ExactGPLayer(train_x=None, train_y=None, likelihood=likelihood, config=self.config,
                                 dims=self.fixed_context_size)
        self.model = model.to(self.device)
        self.likelihood = likelihood.to(self.device)
        self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model).to(self.device)

    def set_forward(self, x, is_feature=False):
        pass

    def set_forward_loss(self, x):
        pass

    def epoch_end(self):
        RandomTaskGenerator.shuffle([1])

    def train_loop(self, epoch, optimizer, scheduler_fn=None):
        if scheduler_fn:
            scheduler = scheduler_fn(optimizer, 1)
        self.epoch_end()
        assert (self.training)
        for task in range(self.tasks):
            inputs, labels = self.get_batch(task)
            for _ in range(self.n_inner_steps):
                optimizer.zero_grad()
                z = self.feature_extractor(inputs)
                self.model.set_train_data(inputs=z, targets=labels, strict=False)
                predictions = self.model(z)
                loss = -self.mll(predictions, self.model.train_targets)
                loss.backward()
                optimizer.step()
                mse = self.mse(predictions.mean, labels)
                self.train_metrics.update(loss, self.model.likelihood.noise, mse)
            if scheduler_fn:
                scheduler.step()

        training_results = self.train_metrics.get()
        for k, v in training_results.items():
            self.train_summary_writer.add_scalar(k, v, epoch)
        mse, loss = self.test_loop(train=False)
        self.valid_metrics.update(loss, np.array(0), mse)

        logging.info(self.train_metrics.report() + " " + self.valid_metrics.report())
        validation_results = self.valid_metrics.get()
        for k, v in validation_results.items():
            self.valid_summary_writer.add_scalar(k, v, epoch)
        self.feature_extractor.train()
        self.likelihood.train()
        self.model.train()

        if validation_results["loss"] < self.curr_valid_loss:
            self.save_checkpoint(os.path.join(self.checkpoint_path, "weights"))
            self.curr_valid_loss = validation_results["loss"]
        self.valid_metrics.reset()
        self.train_metrics.reset()

    def test_loop(self, train=None, optimizer=None):  # no optimizer needed for GP
        (x_support, y_support), (x_query, y_query) = self.get_support_and_queries(train)
        z_support = self.feature_extractor(x_support).detach()
        self.model.set_train_data(inputs=z_support, targets=y_support, strict=False)
        self.model.eval()
        self.feature_extractor.eval()
        self.likelihood.eval()

        with torch.no_grad():
            z_query = self.feature_extractor(x_query).detach()
            pred = self.likelihood(self.model(z_query))
            loss = -self.mll(pred, y_query)
            lower, upper = pred.confidence_region()  # 2 standard deviations above and below the mean

        mse = self.mse(pred.mean, y_query)

        return mse, loss

    def get_batch(self, task, valid=False):

        # we want to fit the gp given context info to new observations
        # TODO: scale the response as in FSBO (needed for train)
        Lambda, response = np.array(self.dataset_enc), np.array(self.acc)

        inputs, labels = Lambda[task], response[task]
        inputs, labels = totorch([inputs], device=self.device), totorch([labels], device=self.device)
        return inputs, labels

    def get_support_and_queries(self, task, train=False):

        # TODO: scale the response as in FSBO (not necessary for test)
        Lambda, response = np.array(self.dataset_enc_val), np.array(self.acc_val)
        card, dim = Lambda.shape

        support = RandomSupportGenerator.choice(np.arange(card),
                                                replace=False, size=self.fixed_context_size)
        indexes = RandomQueryGenerator.choice(
            np.setdiff1d(np.arange(card), support), replace=False, size=self.minibatch_size)

        support_x, support_y = Lambda[support], response[support]
        query_x, query_y = Lambda[indexes], response[indexes]

        return (totorch(support_x, self.device), totorch(support_y.reshape(-1,), self.device)), \
               (totorch(query_x, self.device), totorch(query_y.reshape(-1,), self.device))

    def save_checkpoint(self, checkpoint):
        # save state
        gp_state_dict = self.model.state_dict()
        likelihood_state_dict = self.likelihood.state_dict()
        nn_state_dict = self.feature_extractor.state_dict()
        torch.save({'gp': gp_state_dict, 'likelihood': likelihood_state_dict, 'net': nn_state_dict}, checkpoint)

    def load_checkpoint(self, checkpoint):
        ckpt = torch.load(checkpoint)
        self.model.load_state_dict(ckpt['gp'])
        self.likelihood.load_state_dict(ckpt['likelihood'])
        self.feature_extractor.load_state_dict(ckpt['net'])

    def predict(self, x_support, y_support, x_query, y_query, GP_only=False):
        if not GP_only:
            z_support = self.feature_extractor(x_support).detach()
        else:
            z_support = x_support
        self.model.set_train_data(inputs=z_support, targets=y_support, strict=False)
        self.model.eval()
        self.feature_extractor.eval()
        self.likelihood.eval()

        with torch.no_grad():
            if not GP_only:
                z_query = self.feature_extractor(x_query).detach()
            else:
                z_query = x_query
            pred = self.likelihood(self.model(z_query))
            mu = pred.mean.detach().to("cpu").numpy().reshape(-1,)
            stddev = pred.stddev.detach().to("cpu").numpy().reshape(-1,)
        return mu, stddev
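A minimal smoke test for two of the building blocks above, using synthetic tensors (the sizes are arbitrary assumptions; this is not part of the original training pipeline, which drives these classes through DeepKernelGP/DKT):

```python
if __name__ == '__main__':
    import torch
    import gpytorch

    # Exact GP on random 4-d features.
    config = {"kernel": "rbf", "ard": False}
    train_x, train_y = torch.randn(10, 4), torch.randn(10)
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    gp = ExactGPLayer(train_x, train_y, likelihood, config, dims=4)
    gp.eval()
    likelihood.eval()
    with torch.no_grad():
        pred = likelihood(gp(torch.randn(3, 4)))
    print(pred.mean.shape, pred.stddev.shape)   # torch.Size([3]) torch.Size([3])

    # Two-layer MLP backbone.
    net = batch_mlp(d_in=4, output_sizes=[32, 16])
    print(net(torch.randn(10, 4)).shape)        # torch.Size([10, 16])
```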
@@ -0,0 +1,168 @@
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets
This code is for the MobileNetV3 search space experiments.


## Prerequisites
- Python 3.6 (Anaconda)
- PyTorch 1.6.0
- CUDA 10.2
- python-igraph==0.8.2
- tqdm==4.50.2
- torchvision==0.7.0
- scipy==1.5.2
- ofa==0.0.4-2007200808


## MobileNetV3 Search Space
Go to the folder for MobileNetV3 experiments (i.e. ```MetaD2A_mobilenetV3```)

The overall flow is summarized as follows:
- Building database for Predictor
- Meta-Training Predictor
- Building database for Generator with trained Predictor
- Meta-Training Generator
- Meta-Testing (Searching)
- Evaluating the Searched architecture


## Data Preparation
To download the preprocessed data files, run ```get_files/get_preprocessed_data.py```:
```shell script
$ python get_files/get_preprocessed_data.py
```
It will take some time to download and preprocess each dataset.


## Meta Test and Evaluation
### Meta-Test

You can download the trained checkpoint files for the generator and predictor:
```shell script
$ python get_files/get_generator_checkpoint.py
$ python get_files/get_predictor_checkpoint.py
```

If you want to meta-test with your own dataset, please first build your own preprocessed data
by modifying ```process_dataset.py```:
```shell script
$ python process_dataset.py
```

This code automatically generates neural architectures and then
selects high-performing architectures among the candidates.
By setting ```--data-name``` to the name of a dataset (i.e. ```cifar10```, ```cifar100```, ```aircraft100```, ```pets```),
you can evaluate that specific dataset.

```shell script
# Meta-testing
$ python main.py --gpu 0 --model generator --hs 56 --nz 56 --test --load-epoch 120 --num-gen-arch 200 --data-name {DATASET_NAME}
```

### Architecture Evaluation (MetaD2A vs NSGANetV2)
##### Dataset Preparation
You need to download the Oxford-IIIT Pet dataset to evaluate with ```--data-name pets```:
```shell script
$ python get_files/get_pets.py
```
All other datasets (```cifar10```, ```cifar100```, ```aircraft100```) are downloaded automatically.

##### Evaluation
You can train the searched architecture by running ```evaluation/main```. The code is based on NSGANetV2.

Go to the evaluation folder (i.e. ```evaluation```)
```shell script
$ cd evaluation
```

This automatically runs the top-1 predicted architecture derived by MetaD2A.
```shell script
python main.py --data-name cifar10 --num-gen-arch 200
```
You can also impose a FLOPs constraint with the ```bound``` option.
```shell script
python main.py --data-name cifar10 --num-gen-arch 200 --bound 300
```

You can compare MetaD2A with NSGANetV2, but you need to download some files provided
by [NSGANetV2](https://github.com/human-analysis/nsganetv2).

```shell script
python main.py --data-name cifar10 --num-gen-arch 200 --model-config flops@232
```


## Meta-Training MetaD2A Model
To build the database for meta-training, you need to set ```IMGNET_PATH```, the directory of ILSVRC2012.

### Database Building for Predictor
We recommend running multiple instances of ```create_database.sh``` simultaneously to build the database quickly.
You need to set ```IMGNET_PATH``` in the shell script.
```shell script
# Examples
bash create_database.sh 0,1,2,3 0 49 predictor
bash create_database.sh all 50 99 predictor
...
```
After enough data has been gathered, run ```build_database.py``` to collect the pieces into one file.
```shell script
python build_database.py --model_name predictor --collect
```

We also provide the database we use. To download it, run ```get_files/get_predictor_database.py```:
```shell script
$ python get_files/get_predictor_database.py
```

### Meta-Train Predictor
You can train the predictor as follows:
```shell script
# Meta-training for predictor
$ python main.py --gpu 0 --model predictor --hs 512 --nz 56
```
### Database Building for Generator
We recommend running multiple instances of ```create_database.sh``` simultaneously to build the database quickly.
```shell script
# Examples
bash create_database.sh 4,5,6,7 0 49 generator
bash create_database.sh all 50 99 generator
...
```
After enough data has been gathered, run ```build_database.py``` to collect the pieces into one file.
```shell script
python build_database.py --model_name generator --collect
```

We also provide the database we use. To download it, run ```get_files/get_generator_database.py```:
```shell script
$ python get_files/get_generator_database.py
```


### Meta-Train Generator
You can train the generator as follows:
```shell script
# Meta-training for generator
$ python main.py --gpu 0 --model generator --hs 56 --nz 56
```


## Citation
If you found the provided code useful, please cite our work.
```
@inproceedings{
    lee2021rapid,
    title={Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets},
    author={Hayeon Lee and Eunyoung Hyung and Sung Ju Hwang},
    booktitle={ICLR},
    year={2021}
}
```

## Reference
- [Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks (ICML2019)](https://github.com/juho-lee/set_transformer)
- [D-VAE: A Variational Autoencoder for Directed Acyclic Graphs, Advances in Neural Information Processing Systems (NeurIPS2019)](https://github.com/muhanzhang/D-VAE)
- [Once for All: Train One Network and Specialize it for Efficient Deployment (ICLR2020)](https://github.com/mit-han-lab/once-for-all)
- [NSGANetV2: Evolutionary Multi-Objective Surrogate-Assisted Neural Architecture Search (ECCV2020)](https://github.com/human-analysis/nsganetv2)
@@ -0,0 +1,49 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
import random
import numpy as np
import torch
from parser import get_parser
from predictor import PredictorModel
from database import DatabaseOFA
from utils import load_graph_config


def main():
    args = get_parser()

    if args.gpu == 'all':
        device_list = range(torch.cuda.device_count())
        args.gpu = ','.join(str(_) for _ in device_list)
    else:
        device_list = [int(_) for _ in args.gpu.split(',')]
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    args.device = torch.device("cuda:0")
    args.batch_size = args.batch_size * max(len(device_list), 1)

    torch.cuda.manual_seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    args.model_path = os.path.join(args.save_path, args.model_name, 'model')

    if args.model_name == 'generator':
        graph_config = load_graph_config(
            args.graph_data_name, args.nvt, args.data_path)
        model = PredictorModel(args, graph_config)
        d = DatabaseOFA(args, model)
    else:
        d = DatabaseOFA(args)

    if args.collect:
        d.collect_db()
    else:
        assert args.index is not None
        assert args.imgnet is not None
        d.make_db()


if __name__ == '__main__':
    main()
@@ -0,0 +1,15 @@
# bash create_database.sh all 0 49 predictor

IMGNET_PATH='/w14/dataset/ILSVRC2012' # PUT YOUR ILSVRC2012 DIR

for ((ind=$2;ind<=$3;ind++))
do
  python build_database.py --gpu $1 \
                           --model_name $4 \
                           --index $ind \
                           --imgnet $IMGNET_PATH \
                           --hs 512 \
                           --nz 56
done
@@ -0,0 +1,5 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from .db_ofa import DatabaseOFA
@@ -0,0 +1,57 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################

import numpy as np
import torch

__all__ = ['DataProvider']


class DataProvider:
    SUB_SEED = 937162211  # random seed for sampling subset
    VALID_SEED = 2147483647  # random seed for the validation set

    @staticmethod
    def name():
        """ Return name of the dataset """
        raise NotImplementedError

    @property
    def data_shape(self):
        """ Return shape as python list of one data entry """
        raise NotImplementedError

    @property
    def n_classes(self):
        """ Return `int` of num classes """
        raise NotImplementedError

    @property
    def save_path(self):
        """ local path to save the data """
        raise NotImplementedError

    @property
    def data_url(self):
        """ link to download the data """
        raise NotImplementedError

    @staticmethod
    def random_sample_valid_set(train_size, valid_size):
        assert train_size > valid_size

        g = torch.Generator()
        g.manual_seed(DataProvider.VALID_SEED)  # set random seed before sampling validation set
        rand_indexes = torch.randperm(train_size, generator=g).tolist()

        valid_indexes = rand_indexes[:valid_size]
        train_indexes = rand_indexes[valid_size:]
        return train_indexes, valid_indexes

    @staticmethod
    def labels_to_one_hot(n_classes, labels):
        new_labels = np.zeros((labels.shape[0], n_classes), dtype=np.float32)
        new_labels[range(labels.shape[0]), labels] = np.ones(labels.shape)
        return new_labels
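The two static helpers above are self-contained; a quick sanity check with synthetic sizes (an illustrative assumption, not part of the original module):

```python
if __name__ == '__main__':
    import numpy as np

    # Deterministic split thanks to the fixed VALID_SEED.
    train_idx, valid_idx = DataProvider.random_sample_valid_set(train_size=10, valid_size=3)
    print(len(train_idx), len(valid_idx))    # 7 3

    labels = np.array([0, 2, 1])
    print(DataProvider.labels_to_one_hot(3, labels))
    # [[1. 0. 0.]
    #  [0. 0. 1.]
    #  [0. 1. 0.]]
```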
@@ -0,0 +1,107 @@
import os
import torch
import time
import copy
import glob
from .imagenet import ImagenetDataProvider
from .imagenet_loader import ImagenetRunConfig
from .run_manager import RunManager
from ofa.model_zoo import ofa_net


class DatabaseOFA:
    def __init__(self, args, predictor=None):
        self.path = f'{args.data_path}/{args.model_name}'
        self.model_name = args.model_name
        self.index = args.index
        self.args = args
        self.predictor = predictor
        ImagenetDataProvider.DEFAULT_PATH = args.imgnet

        if not os.path.exists(self.path):
            os.makedirs(self.path)

    def make_db(self):
        self.ofa_network = ofa_net('ofa_mbv3_d234_e346_k357_w1.0', pretrained=True)
        self.run_config = ImagenetRunConfig(test_batch_size=self.args.batch_size,
                                            n_worker=20)
        database = []
        st_time = time.time()
        f = open(f'{self.path}/txt_{self.index}.txt', 'w')
        for dn in range(10000):
            best_pp = -1
            best_info = None
            dls = None
            with torch.no_grad():
                if self.model_name == 'generator':
                    for i in range(10):
                        net_setting = self.ofa_network.sample_active_subnet()
                        subnet = self.ofa_network.get_active_subnet(preserve_weight=True)
                        if i == 0:
                            run_manager = RunManager('.tmp/eval_subnet', self.args, subnet,
                                                     self.run_config, init=False, pp=self.predictor)
                            self.run_config.data_provider.assign_active_img_size(224)
                            dls = {j: copy.deepcopy(run_manager.data_loader) for j in range(1, 10)}
                        else:
                            run_manager = RunManager('.tmp/eval_subnet', self.args, subnet,
                                                     self.run_config,
                                                     init=False, data_loader=dls[i], pp=self.predictor)
                        run_manager.reset_running_statistics(net=subnet)

                        loss, (top1, top5), pred_acc \
                            = run_manager.validate(net=subnet, net_setting=net_setting)

                        if best_pp < pred_acc:
                            best_pp = pred_acc
                            print('[%d] class=%d,\t loss=%.5f,\t top1=%.1f,\t top5=%.1f' % (
                                dn, len(run_manager.cls_lst), loss, top1, top5))
                            info_dict = {'loss': loss,
                                         'top1': top1,
                                         'top5': top5,
                                         'net': net_setting,
                                         'class': run_manager.cls_lst,
                                         'params': run_manager.net_info['params'],
                                         'flops': run_manager.net_info['flops'],
                                         'test_transform': run_manager.test_transform
                                         }
                            best_info = info_dict
                elif self.model_name == 'predictor':
                    net_setting = self.ofa_network.sample_active_subnet()
                    subnet = self.ofa_network.get_active_subnet(preserve_weight=True)
                    run_manager = RunManager('.tmp/eval_subnet', self.args, subnet, self.run_config, init=False)
                    self.run_config.data_provider.assign_active_img_size(224)
                    run_manager.reset_running_statistics(net=subnet)

                    loss, (top1, top5), _ = run_manager.validate(net=subnet)
                    print('[%d] class=%d,\t loss=%.5f,\t top1=%.1f,\t top5=%.1f' % (
                        dn, len(run_manager.cls_lst), loss, top1, top5))
                    best_info = {'loss': loss,
                                 'top1': top1,
                                 'top5': top5,
                                 'net': net_setting,
                                 'class': run_manager.cls_lst,
                                 'params': run_manager.net_info['params'],
                                 'flops': run_manager.net_info['flops'],
                                 'test_transform': run_manager.test_transform
                                 }
            database.append(best_info)
            if (len(database)) % 10 == 0:
                msg = f'{(time.time() - st_time) / 60.0:0.2f}(min) save {len(database)} database, {self.index} id'
                print(msg)
                f.write(msg + '\n')
                f.flush()
                torch.save(database, f'{self.path}/database_{self.index}.pt')

    def collect_db(self):
        if not os.path.exists(self.path + f'/processed'):
            os.makedirs(self.path + f'/processed')

        database = []
        dlst = glob.glob(self.path + '/*.pt')
        for filepath in dlst:
            database += torch.load(filepath)

        assert len(database) != 0

        print(f'The number of database: {len(database)}')
        torch.save(database, self.path + f'/processed/collected_database.pt')
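For orientation, `DatabaseOFA` is normally driven by `build_database.py` (shown earlier); a hypothetical standalone invocation might look like the following (the `Namespace` fields mirror the attributes the class reads here, the paths are placeholders, and `RunManager` may consume further `args` attributes not visible in this file):

```python
from argparse import Namespace

args = Namespace(data_path='./data/ofa', model_name='predictor',
                 index=0, imgnet='/path/to/ILSVRC2012', batch_size=256)
db = DatabaseOFA(args)       # predictor mode needs no pretrained predictor
db.make_db()                 # periodically saves ./data/ofa/predictor/database_0.pt
# later, once several indices have finished:
db.collect_db()              # merges *.pt into processed/collected_database.pt
```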
@@ -0,0 +1,240 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import warnings
import os
import torch
import math
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from ofa_local.imagenet_classification.data_providers.base_provider import DataProvider
from ofa_local.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler
from .metaloader import MetaImageNetDataset, EpisodeSampler, MetaDataLoader


__all__ = ['ImagenetDataProvider']


class ImagenetDataProvider(DataProvider):
    DEFAULT_PATH = '/dataset/imagenet'

    def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):
        warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = 'None' if distort_color is None else distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            # was `from ofa.utils.my_dataloader import MyDataLoader`; changed to
            # `ofa_local` for consistency with the imports above (assumed to be
            # the same vendored package)
            from ofa_local.utils.my_dataloader import MyDataLoader
            assert isinstance(self.image_size, list)
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)  # active resolution for test
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        ########################## modification ########################
        train_dataset = self.train_dataset(self.build_train_transform())

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset), valid_size)
            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, True, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, True, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        # test_dataset = self.test_dataset(valid_transforms)
        test_dataset = self.meta_test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            # self.test = torch.utils.data.DataLoader(
            #     test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            # )
            sampler = EpisodeSampler(
                max_way=1000, query=10, ylst=test_dataset.ylst)
            self.test = MetaDataLoader(dataset=test_dataset,
                                       sampler=sampler,
                                       batch_size=test_batch_size,
                                       shuffle=False,
                                       num_workers=4)

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'imagenet'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 1000

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = self.DEFAULT_PATH
            if not os.path.exists(self._save_path):
                self._save_path = os.path.expanduser('~/dataset/imagenet')
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        return datasets.ImageFolder(self.train_path, _transforms)

    def test_dataset(self, _transforms):
        return datasets.ImageFolder(self.valid_path, _transforms)

    def meta_test_dataset(self, _transforms):
        return MetaImageNetDataset('val', max_way=1000, query=10,
                                   dpath='/w14/dataset/ILSVRC2012', transform=_transforms)

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'val')

    @property
    def normalize(self):
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        # random_resize_crop -> random_horizontal_flip
        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]

        # color augmentation (optional)
        color_transform = None
        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        if color_transform is not None:
            train_transforms.append(color_transform)

        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
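
    # Example (added note; `provider` is an ImagenetDataProvider instance, an
    # assumption for illustration): evaluation can switch the test resolution
    # on the fly, and transforms are built once per size and then cached:
    #   provider.assign_active_img_size(192)   # builds & installs the 192px transform
    #   provider.assign_active_img_size(224)   # reuses the cached 224px transform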

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting BN running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, True, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
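
# Usage sketch (added; `provider` is an ImagenetDataProvider instance, an
# assumption): the sub-train loader is materialized once per resolution and
# cached, so repeated BN recalibration reuses the same fixed random subset:
#   batches = provider.build_sub_train_loader(n_images=2000, batch_size=200)
#   # `batches` is a list of (images, labels) tensors at the active resolution,
#   # typically fed to set_running_statistics() to recalibrate BN.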
@@ -0,0 +1,40 @@
from .imagenet import ImagenetDataProvider
from ofa_local.imagenet_classification.run_manager import RunConfig


__all__ = ['ImagenetRunConfig']


class ImagenetRunConfig(RunConfig):

    def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='imagenet', train_batch_size=256, test_batch_size=500, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys=None,
                 mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224, **kwargs):
        super(ImagenetRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )

        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == ImagenetDataProvider.name():
                DataProviderClass = ImagenetDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']
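
# Example (added for illustration): `data_provider` is built lazily on first
# access and memoized in __dict__, so the heavy ImageNet loaders are only
# constructed when actually needed:
#   run_config = ImagenetRunConfig(image_size=224, n_worker=8)
#   provider = run_config.data_provider          # loaders built here, on first access
#   assert provider is run_config.data_provider  # cached afterwards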
@@ -0,0 +1,210 @@
from torch.utils.data.sampler import Sampler
import os
import random
from PIL import Image
from collections import defaultdict
import torch
from torch.utils.data import Dataset, DataLoader
import glob


class RandCycleIter:
    '''
    Return data_list per class
    Shuffle the returning order after one epoch
    '''
    def __init__(self, data, shuffle=True):
        self.data_list = list(data)
        self.length = len(self.data_list)
        self.i = self.length - 1
        self.shuffle = shuffle

    def __iter__(self):
        return self

    def __next__(self):
        self.i += 1

        if self.i == self.length:
            self.i = 0
            if self.shuffle:
                random.shuffle(self.data_list)

        return self.data_list[self.i]
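
# Quick illustration (added; values are hypothetical): the iterator never
# raises StopIteration and reshuffles after each full pass, so class-wise
# sampling never exhausts:
#   it = RandCycleIter([0, 1, 2])
#   [next(it) for _ in range(6)]  # -> e.g. [1, 0, 2, 2, 0, 1]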
class EpisodeSampler(Sampler):
    def __init__(self, max_way, query, ylst):
        self.max_way = max_way
        self.query = query
        self.ylst = ylst
        # self.n_epi = n_epi

        clswise_xidx = defaultdict(list)
        for i, y in enumerate(ylst):
            clswise_xidx[y].append(i)
        self.cws_xidx_iter = [RandCycleIter(cxidx, shuffle=True)
                              for cxidx in clswise_xidx.values()]
        self.n_cls = len(clswise_xidx)

        self.create_episode()

    def __iter__(self):
        return self.get_index()

    def __len__(self):
        return self.get_len()

    def create_episode(self):
        self.way = torch.randperm(int(self.max_way / 10.0) - 1)[0] * 10 + 10
        cls_lst = torch.sort(torch.randperm(self.max_way)[:self.way])[0]
        self.cls_itr = iter(cls_lst)
        self.cls_lst = cls_lst

    def get_len(self):
        return self.way * self.query

    def get_index(self):
        x_itr = self.cws_xidx_iter

        i, j = 0, 0
        while i < self.query * self.way:
            if j >= self.query:
                j = 0
            if j == 0:
                cls_idx = next(self.cls_itr).item()
                bb = [x_itr[cls_idx]] * self.query
                didx = next(zip(*bb))
            yield didx[j]
            # yield (didx[j], self.way)

            i += 1; j += 1
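
# Episode shape (added note, derived from the code above): create_episode()
# draws `way` as a random multiple of 10 in [10, max_way) and samples that many
# class ids; get_index() yields `query` consecutive sample indices per chosen
# class, so one episode is way * query items grouped by class, e.g.
# way=20, query=10 -> 200 indices.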
class MetaImageNetDataset(Dataset):
    def __init__(self, mode='val',
                 max_way=1000, query=10,
                 dpath='/w14/dataset/ILSVRC2012', transform=None):
        self.dpath = dpath
        self.transform = transform
        self.mode = mode

        self.max_way = max_way
        self.query = query
        classes, class_to_idx = self._find_classes(dpath + '/' + mode)
        self.classes, self.class_to_idx = classes, class_to_idx
        # self.class_folder_lst = \
        #     glob.glob(dpath+'/'+mode+'/*')
        # ## sorting alphabetically
        # self.class_folder_lst = sorted(self.class_folder_lst)
        self.file_path_lst, self.ylst = [], []
        for cls in classes:
            xlst = glob.glob(dpath + '/' + mode + '/' + cls + '/*')
            self.file_path_lst += xlst[:self.query]
            y = class_to_idx[cls]
            self.ylst += [y] * len(xlst[:self.query])

        # for y, cls in enumerate(self.class_folder_lst):
        #     xlst = glob.glob(cls+'/*')
        #     self.file_path_lst += xlst[:self.query]
        #     self.ylst += [y] * len(xlst[:self.query])
        #     # self.file_path_lst += [xlst[_] for _ in
        #     #     torch.randperm(len(xlst))[:self.query]]
        #     # self.ylst += [cls.split('/')[-1]] * len(xlst)

        self.way_idx = 0
        self.x_idx = 0
        self.way = 2
        self.cls_lst = None

    def __len__(self):
        return self.way * self.query

    def __getitem__(self, index):
        # if self.way != index[1]:
        #     self.way = index[1]
        # index = index[0]

        x = Image.open(
            self.file_path_lst[index]).convert('RGB')

        if self.transform is not None:
            x = self.transform(x)
        cls_name = self.ylst[index]
        y = self.cls_lst.index(cls_name)
        # y = self.way_idx
        # self.x_idx += 1
        # if self.x_idx == self.query:
        #     self.way_idx += 1
        #     self.x_idx = 0
        # if self.way_idx == self.way:
        #     self.way_idx = 0
        #     self.x_idx = 0
        return x, y  # , cls_name

    def _find_classes(self, dir: str):
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx


class MetaDataLoader(DataLoader):
    def __init__(self,
                 dataset, sampler, batch_size, shuffle, num_workers):
        super(MetaDataLoader, self).__init__(
            dataset=dataset,
            sampler=sampler,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers)

    def create_episode(self):
        self.sampler.create_episode()
        self.dataset.way = self.sampler.way
        self.dataset.cls_lst = self.sampler.cls_lst.tolist()

    def get_cls_idx(self):
        return self.sampler.cls_lst


def get_loader(mode='val', way=10, query=10,
               n_epi=100, dpath='/w14/dataset/ILSVRC2012',
               transform=None):
    # note: the original called an undefined `get_transforms(mode)` and passed
    # `n_epi` to EpisodeSampler, whose __init__ takes (max_way, query, ylst);
    # both are fixed here to match the definitions above
    dataset = MetaImageNetDataset(mode, way, query, dpath, transform)
    sampler = EpisodeSampler(way, query, dataset.ylst)
    dataset.way = sampler.way
    dataset.cls_lst = sampler.cls_lst.tolist()
    loader = MetaDataLoader(dataset=dataset,
                            sampler=sampler,
                            batch_size=10,
                            shuffle=False,
                            num_workers=4)
    return loader


# trloader = get_loader()

# trloader.create_episode()
# print(len(trloader))
# print(trloader.dataset.way)
# print(trloader.sampler.way)
# for i, episode in enumerate(trloader, start=1):
#     print(episode[2])
@@ -0,0 +1,302 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import os
import json
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from tqdm import tqdm
from utils import decode_ofa_mbv3_to_igraph
from ofa_local.utils import get_net_info, cross_entropy_loss_with_soft_target, cross_entropy_with_label_smoothing
from ofa_local.utils import AverageMeter, accuracy, write_log, mix_images, mix_labels, init_models

__all__ = ['RunManager']
import torchvision.models as models


class RunManager:

    def __init__(self, path, args, net, run_config, init=True, measure_latency=None,
                 no_gpu=False, data_loader=None, pp=None):
        self.path = path
        self.mode = args.model_name
        self.net = net
        self.run_config = run_config

        self.best_acc = 0
        self.start_epoch = 0

        os.makedirs(self.path, exist_ok=True)
        # dataloader
        if data_loader is not None:
            self.data_loader = data_loader
            cls_lst = self.data_loader.get_cls_idx()
            self.cls_lst = cls_lst
        else:
            self.data_loader = self.run_config.valid_loader
            self.data_loader.create_episode()
            cls_lst = self.data_loader.get_cls_idx()
            self.cls_lst = cls_lst

        state_dict = self.net.classifier.state_dict()
        new_state_dict = {'weight': state_dict['linear.weight'][cls_lst],
                          'bias': state_dict['linear.bias'][cls_lst]}

        self.net.classifier = nn.Linear(1280, len(cls_lst), bias=True)
        self.net.classifier.load_state_dict(new_state_dict)

        # move network to GPU if available
        if torch.cuda.is_available() and (not no_gpu):
            self.device = torch.device('cuda:0')
            self.net = self.net.to(self.device)
            cudnn.benchmark = True
        else:
            self.device = torch.device('cpu')

        # net info
        net_info = get_net_info(
            self.net, self.run_config.data_provider.data_shape, measure_latency, False)
        self.net_info = net_info
        self.test_transform = self.run_config.data_provider.test.dataset.transform

        # criterion
        if isinstance(self.run_config.mixup_alpha, float):
            self.train_criterion = cross_entropy_loss_with_soft_target
        elif self.run_config.label_smoothing > 0:
            self.train_criterion = \
                lambda pred, target: cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
        else:
            self.train_criterion = nn.CrossEntropyLoss()
        self.test_criterion = nn.CrossEntropyLoss()

        # optimizer
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split('#')
            net_params = [
                self.network.get_parameters(keys, mode='exclude'),  # parameters with weight decay
                self.network.get_parameters(keys, mode='include'),  # parameters without weight decay
            ]
        else:
            # noinspection PyBroadException
            try:
                net_params = self.network.weight_parameters()
            except Exception:
                net_params = []
                for param in self.network.parameters():
                    if param.requires_grad:
                        net_params.append(param)
        self.optimizer = self.run_config.build_optimizer(net_params)

        self.net = torch.nn.DataParallel(self.net)

        if self.mode == 'generator':
            # PP
            save_dir = f'{args.save_path}/predictor/model/ckpt_max_corr.pt'

            self.acc_predictor = pp.to('cuda')
            self.acc_predictor.load_state_dict(torch.load(save_dir))
            self.acc_predictor = torch.nn.DataParallel(self.acc_predictor)
            model = models.resnet18(pretrained=True).eval()
            feature_extractor = torch.nn.Sequential(*list(model.children())[:-1]).to(self.device)
            self.feature_extractor = torch.nn.DataParallel(feature_extractor)

    """ save path and log path """

    @property
    def save_path(self):
        if self.__dict__.get('_save_path', None) is None:
            save_path = os.path.join(self.path, 'checkpoint')
            os.makedirs(save_path, exist_ok=True)
            self.__dict__['_save_path'] = save_path
        return self.__dict__['_save_path']

    @property
    def logs_path(self):
        if self.__dict__.get('_logs_path', None) is None:
            logs_path = os.path.join(self.path, 'logs')
            os.makedirs(logs_path, exist_ok=True)
            self.__dict__['_logs_path'] = logs_path
        return self.__dict__['_logs_path']

    @property
    def network(self):
        return self.net.module if isinstance(self.net, nn.DataParallel) else self.net

    def write_log(self, log_str, prefix='valid', should_print=True, mode='a'):
        write_log(self.logs_path, log_str, prefix, should_print, mode)

    """ save and load models """

    def save_model(self, checkpoint=None, is_best=False, model_name=None):
        if checkpoint is None:
            checkpoint = {'state_dict': self.network.state_dict()}

        if model_name is None:
            model_name = 'checkpoint.pth.tar'

        checkpoint['dataset'] = self.run_config.dataset  # add `dataset` info to the checkpoint
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        model_path = os.path.join(self.save_path, model_name)
        with open(latest_fname, 'w') as fout:
            fout.write(model_path + '\n')
        torch.save(checkpoint, model_path)

        if is_best:
            best_path = os.path.join(self.save_path, 'model_best.pth.tar')
            torch.save({'state_dict': checkpoint['state_dict']}, best_path)

    def load_model(self, model_fname=None):
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        if model_fname is None and os.path.exists(latest_fname):
            with open(latest_fname, 'r') as fin:
                model_fname = fin.readline()
                if model_fname[-1] == '\n':
                    model_fname = model_fname[:-1]
        # noinspection PyBroadException
        try:
            if model_fname is None or not os.path.exists(model_fname):
                model_fname = '%s/checkpoint.pth.tar' % self.save_path
                with open(latest_fname, 'w') as fout:
                    fout.write(model_fname + '\n')
            print("=> loading checkpoint '{}'".format(model_fname))
            checkpoint = torch.load(model_fname, map_location='cpu')
        except Exception:
            print('fail to load checkpoint from %s' % self.save_path)
            return {}

        self.network.load_state_dict(checkpoint['state_dict'])
        if 'epoch' in checkpoint:
            self.start_epoch = checkpoint['epoch'] + 1
        if 'best_acc' in checkpoint:
            self.best_acc = checkpoint['best_acc']
        if 'optimizer' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        print("=> loaded checkpoint '{}'".format(model_fname))
        return checkpoint

    def save_config(self, extra_run_config=None, extra_net_config=None):
        """ dump run_config and net_config to the model_folder """
        run_save_path = os.path.join(self.path, 'run.config')
        if not os.path.isfile(run_save_path):
            run_config = self.run_config.config
            if extra_run_config is not None:
                run_config.update(extra_run_config)
            json.dump(run_config, open(run_save_path, 'w'), indent=4)
            print('Run configs dump to %s' % run_save_path)

        try:
            net_save_path = os.path.join(self.path, 'net.config')
            net_config = self.network.config
            if extra_net_config is not None:
                net_config.update(extra_net_config)
            json.dump(net_config, open(net_save_path, 'w'), indent=4)
            print('Network configs dump to %s' % net_save_path)
        except Exception:
            print('%s do not support net config' % type(self.network))

    """ metric related """

    def get_metric_dict(self):
        return {
            'top1': AverageMeter(),
            'top5': AverageMeter(),
        }

    def update_metric(self, metric_dict, output, labels):
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        metric_dict['top1'].update(acc1[0].item(), output.size(0))
        metric_dict['top5'].update(acc5[0].item(), output.size(0))

    def get_metric_vals(self, metric_dict, return_dict=False):
        if return_dict:
            return {
                key: metric_dict[key].avg for key in metric_dict
            }
        else:
            return [metric_dict[key].avg for key in metric_dict]

    def get_metric_names(self):
        return 'top1', 'top5'
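
    # Illustrative note (added): the meters accumulate running averages over
    # batches, so a typical sequence inside an evaluation loop is
    #   md = self.get_metric_dict()
    #   self.update_metric(md, output, labels)   # folds in one batch
    #   top1, top5 = self.get_metric_vals(md)    # averages over all batches so far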
""" train and test """
|
||||
def validate(self, epoch=0, is_test=False, run_str='', net=None,
|
||||
data_loader=None, no_logs=False, train_mode=False, net_setting=None):
|
||||
if net is None:
|
||||
net = self.net
|
||||
if not isinstance(net, nn.DataParallel):
|
||||
net = nn.DataParallel(net)
|
||||
|
||||
if data_loader is not None:
|
||||
self.data_loader = data_loader
|
||||
|
||||
if train_mode:
|
||||
net.train()
|
||||
else:
|
||||
net.eval()
|
||||
|
||||
losses = AverageMeter()
|
||||
metric_dict = self.get_metric_dict()
|
||||
|
||||
features_stack = []
|
||||
with torch.no_grad():
|
||||
with tqdm(total=len(self.data_loader),
|
||||
desc='Validate Epoch #{} {}'.format(epoch + 1, run_str), disable=no_logs) as t:
|
||||
for i, (images, labels) in enumerate(self.data_loader):
|
||||
images, labels = images.to(self.device), labels.to(self.device)
|
||||
if self.mode == 'generator':
|
||||
features = self.feature_extractor(images).squeeze()
|
||||
features_stack.append(features)
|
||||
# compute output
|
||||
output = net(images)
|
||||
loss = self.test_criterion(output, labels)
|
||||
# measure accuracy and record loss
|
||||
self.update_metric(metric_dict, output, labels)
|
||||
|
||||
losses.update(loss.item(), images.size(0))
|
||||
t.set_postfix({
|
||||
'loss': losses.avg,
|
||||
**self.get_metric_vals(metric_dict, return_dict=True),
|
||||
'img_size': images.size(2),
|
||||
})
|
||||
t.update(1)
|
||||
|
||||
if self.mode == 'generator':
|
||||
features_stack = torch.cat(features_stack)
|
||||
igraph_g = decode_ofa_mbv3_to_igraph(net_setting)[0]
|
||||
D_mu = self.acc_predictor.module.set_encode(features_stack.unsqueeze(0).to('cuda'))
|
||||
G_mu = self.acc_predictor.module.graph_encode(igraph_g)
|
||||
pred_acc = self.acc_predictor.module.predict(D_mu.unsqueeze(0), G_mu).item()
|
||||
|
||||
return losses.avg, self.get_metric_vals(metric_dict), \
|
||||
pred_acc if self.mode == 'generator' else None
|
||||
|
||||
|
||||
def validate_all_resolution(self, epoch=0, is_test=False, net=None):
|
||||
if net is None:
|
||||
net = self.network
|
||||
if isinstance(self.run_config.data_provider.image_size, list):
|
||||
img_size_list, loss_list, top1_list, top5_list = [], [], [], []
|
||||
for img_size in self.run_config.data_provider.image_size:
|
||||
img_size_list.append(img_size)
|
||||
self.run_config.data_provider.assign_active_img_size(img_size)
|
||||
self.reset_running_statistics(net=net)
|
||||
loss, (top1, top5) = self.validate(epoch, is_test, net=net)
|
||||
loss_list.append(loss)
|
||||
top1_list.append(top1)
|
||||
top5_list.append(top5)
|
||||
return img_size_list, loss_list, top1_list, top5_list
|
||||
else:
|
||||
loss, (top1, top5) = self.validate(epoch, is_test, net=net)
|
||||
return [self.run_config.data_provider.active_img_size], [loss], [top1], [top5]
|
||||
|
||||
def reset_running_statistics(self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None):
|
||||
from ofa_local.imagenet_classification.elastic_nn.utils import set_running_statistics
|
||||
if net is None:
|
||||
net = self.network
|
||||
if data_loader is None:
|
||||
data_loader = self.run_config.random_sub_train_loader(subset_size, subset_batch_size)
|
||||
set_running_statistics(net, data_loader)
|
||||
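
# Usage sketch (added; the argument objects are assumptions): evaluating one
# subnet on a sampled episode roughly follows
#   rm = RunManager('.tmp/eval_subnet', args, subnet, run_config, init=False)
#   rm.reset_running_statistics(net=subnet)   # recalibrate BN on a small split
#   loss, (top1, top5), pred = rm.validate(net=subnet, net_setting=net_setting)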
@@ -0,0 +1,4 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
@@ -0,0 +1,401 @@
from __future__ import print_function

import os
import math
import warnings
import numpy as np

# from timm.data.transforms import _pil_interp
from timm.data.auto_augment import rand_augment_transform

import torch.utils.data
import torchvision.transforms as transforms
from torchvision.datasets.folder import default_loader
from PIL import Image  # added: used below as the RandAugment interpolation constant

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


def make_dataset(dir, image_ids, targets):
    assert (len(image_ids) == len(targets))
    images = []
    dir = os.path.expanduser(dir)
    for i in range(len(image_ids)):
        item = (os.path.join(dir, 'data', 'images',
                             '%s.jpg' % image_ids[i]), targets[i])
        images.append(item)
    return images


def find_classes(classes_file):
    # read classes file, separating out image IDs and class names
    image_ids = []
    targets = []
    f = open(classes_file, 'r')
    for line in f:
        split_line = line.split(' ')
        image_ids.append(split_line[0])
        targets.append(' '.join(split_line[1:]))
    f.close()

    # index class names
    classes = np.unique(targets)
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    targets = [class_to_idx[c] for c in targets]

    return (image_ids, targets, classes, class_to_idx)
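
# Shape note (added, derived from the code above): for a classes file with
# lines "<image_id> <class name>", find_classes returns parallel lists:
#   image_ids -> ['0034309', ...]; targets -> [17, ...]  (integer class indices)
#   classes   -> sorted unique class names; class_to_idx maps name -> index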
class FGVCAircraft(torch.utils.data.Dataset):
    """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.
    Args:
        root (string): Root directory path to dataset.
        class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
            to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g. ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
    class_types = ('variant', 'family', 'manufacturer')
    splits = ('train', 'val', 'trainval', 'test')

    def __init__(self, root, class_type='variant', split='train', transform=None,
                 target_transform=None, loader=default_loader, download=False):
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))
        if class_type not in self.class_types:
            raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
                class_type, ', '.join(self.class_types),
            ))
        self.root = os.path.expanduser(root)
        self.class_type = class_type
        self.split = split
        self.classes_file = os.path.join(self.root, 'data',
                                         'images_%s_%s.txt' % (self.class_type, self.split))

        if download:
            self.download()

        (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
        samples = make_dataset(self.root, image_ids, targets)

        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

        self.samples = samples
        self.classes = classes
        self.class_to_idx = class_to_idx

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """

        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
            os.path.exists(self.classes_file)

    def download(self):
        """Download the FGVC-Aircraft data if it doesn't exist already."""
        from six.moves import urllib
        import tarfile

        if self._check_exists():
            return

        # prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz
        print('Downloading %s ... (may take a few minutes)' % self.url)

        parent_dir = os.path.abspath(os.path.join(self.root, os.pardir))
        tar_name = self.url.rpartition('/')[-1]
        tar_path = os.path.join(parent_dir, tar_name)
        data = urllib.request.urlopen(self.url)

        # download .tar.gz file
        with open(tar_path, 'wb') as f:
            f.write(data.read())

        # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b
        # (str.strip removes a character set, not a suffix, so slice the
        # extension off instead of `tar_path.strip('.tar.gz')`)
        data_folder = tar_path[:-len('.tar.gz')]
        print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder))
        tar = tarfile.open(tar_path)
        tar.extractall(parent_dir)

        # if necessary, rename data folder to self.root
        if not os.path.samefile(data_folder, self.root):
            print('Renaming %s to %s ...' % (data_folder, self.root))
            os.rename(data_folder, self.root)

        # delete .tar.gz file
        print('Deleting %s ...' % tar_path)
        os.remove(tar_path)

        print('Done!')
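
# Example (added for illustration; the path is an assumption):
#   ds = FGVCAircraft(root='/mnt/datastore/Aircraft', split='trainval',
#                     transform=transforms.ToTensor(), download=False)
#   img, label = ds[0]   # 3xHxW tensor and an int in [0, len(ds.classes))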
class FGVCAircraftDataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):

        warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.samples) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'aircraft'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 100

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/Aircraft'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/Aircraft'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.train_path, _transforms)
        dataset = FGVCAircraft(
            root=self.train_path, split='trainval', download=True, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.valid_path, _transforms)
        dataset = FGVCAircraft(
            root=self.valid_path, split='test', download=True, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        return self.save_path

    @property
    def valid_path(self):
        return self.save_path

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.48933587508932375, 0.5183537408957618, 0.5387914411673883],
            std=[0.22388883112804625, 0.21641635409388751, 0.24615605842636115])

    def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
        if image_size is None:
            image_size = self.image_size
        # if print_log:
        #     print('Color jitter: %s, resize_scale: %s, img_size: %s' %
        #           (self.distort_color, self.resize_scale, image_size))

        # if self.distort_color == 'torch':
        #     color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        # elif self.distort_color == 'tf':
        #     color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        # else:
        #     color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
            img_size_min = min(image_size)
        else:
            resize_transform_class = transforms.RandomResizedCrop
            img_size_min = image_size

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]

        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in [0.48933587508932375, 0.5183537408957618,
                                                               0.5387914411673883]]),
        )
        # was `transforms.Resize(image_size)`, which PIL ops cannot use as a
        # resample filter; a PIL interpolation constant matches the intent of
        # the commented `_pil_interp('bicubic')` import above (assumed fix)
        aa_params['interpolation'] = Image.BICUBIC
        train_transforms += [rand_augment_transform(auto_augment, aa_params)]

        # if color_transform is not None:
        #     train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]


if __name__ == '__main__':
    data = FGVCAircraft(root='/mnt/datastore/Aircraft',
                        split='trainval', download=True)
    print(len(data.classes))
    print(len(data.samples))
@@ -0,0 +1,238 @@
|
||||
"""
|
||||
Taken from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
|
||||
"""
|
||||
|
||||
from PIL import Image, ImageEnhance, ImageOps
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
|
||||
class ImageNetPolicy(object):
|
||||
""" Randomly choose one of the best 24 Sub-policies on ImageNet.
|
||||
|
||||
Example:
|
||||
>>> policy = ImageNetPolicy()
|
||||
>>> transformed = policy(image)
|
||||
|
||||
Example as a PyTorch Transform:
|
||||
>>> transform=transforms.Compose([
|
||||
>>> transforms.Resize(256),
|
||||
>>> ImageNetPolicy(),
|
||||
>>> transforms.ToTensor()])
|
||||
"""
|
||||
def __init__(self, fillcolor=(128, 128, 128)):
|
||||
self.policies = [
|
||||
SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
|
||||
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
|
||||
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
|
||||
SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
|
||||
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
|
||||
|
||||
SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
|
||||
SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
|
||||
SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
|
||||
SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
|
||||
SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
|
||||
|
||||
SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
|
||||
SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
|
||||
SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
|
||||
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
|
||||
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
|
||||
|
||||
SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
|
||||
SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
|
||||
SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
|
||||
SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
|
||||
SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
|
||||
|
||||
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
|
||||
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
|
||||
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
|
||||
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
|
||||
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
|
||||
]
|
||||
|
||||
|
||||
def __call__(self, img):
|
||||
policy_idx = random.randint(0, len(self.policies) - 1)
|
||||
return self.policies[policy_idx](img)
|
||||
|
||||
def __repr__(self):
|
||||
return "AutoAugment ImageNet Policy"
|
||||
|
||||
|
||||
class CIFAR10Policy(object):
|
||||
""" Randomly choose one of the best 25 Sub-policies on CIFAR10.
|
||||
|
||||
Example:
|
||||
>>> policy = CIFAR10Policy()
|
||||
>>> transformed = policy(image)
|
||||
|
||||
Example as a PyTorch Transform:
|
||||
>>> transform=transforms.Compose([
|
||||
>>> transforms.Resize(256),
|
||||
>>> CIFAR10Policy(),
|
||||
>>> transforms.ToTensor()])
|
||||
"""
|
||||
def __init__(self, fillcolor=(128, 128, 128)):
|
||||
self.policies = [
|
||||
SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
|
||||
SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
|
||||
SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
|
||||
SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
|
||||
SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
|
||||
|
||||
SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
|
||||
SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
|
||||
SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
|
||||
SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
|
||||
SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
|
||||
|
||||
SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
|
||||
SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
|
||||
SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
|
||||
SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
|
||||
SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
|
||||
|
||||
SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
|
||||
SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
|
||||
SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
|
||||
SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
|
||||
SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
|
||||
|
||||
SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
|
||||
SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
|
||||
SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
|
||||
SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
|
||||
SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
|
||||
]
|
||||
|
||||
|
||||
def __call__(self, img):
|
||||
policy_idx = random.randint(0, len(self.policies) - 1)
|
||||
return self.policies[policy_idx](img)
|
||||
|
||||
def __repr__(self):
|
||||
return "AutoAugment CIFAR10 Policy"
|
||||
|
||||
|
||||
class SVHNPolicy(object):
|
||||
""" Randomly choose one of the best 25 Sub-policies on SVHN.
|
||||
|
||||
Example:
|
||||
>>> policy = SVHNPolicy()
|
||||
>>> transformed = policy(image)
|
||||
|
||||
Example as a PyTorch Transform:
|
||||
>>> transform=transforms.Compose([
|
||||
>>> transforms.Resize(256),
|
||||
>>> SVHNPolicy(),
|
||||
>>> transforms.ToTensor()])
|
||||
"""
|
||||
def __init__(self, fillcolor=(128, 128, 128)):
|
||||
self.policies = [
|
||||
SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
|
||||
SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
|
||||
SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
|
||||
SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
|
||||
SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
|
||||
|
||||
SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
|
||||
SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
|
||||
SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
|
||||
SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
|
||||
SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
|
||||
|
||||
SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
|
||||
SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
|
||||
SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
|
||||
SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
|
||||
SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
|
||||
|
||||
SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
|
||||
SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
            SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),

            SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
            SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
            SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
            SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
            SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment SVHN Policy"


class SubPolicy(object):

    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            # `np.int` was removed in NumPy >= 1.24; the builtin `int` is equivalent here
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10
        }

        # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        def rotate_with_fill(img, magnitude):
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)

        func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "shearY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "translateX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor),
            "translateY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                fillcolor=fillcolor),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
            "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
            "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img)
        }

        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img
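# --- Usage sketch (editor's addition, not part of the original file) ----
# A minimal example of applying one sub-policy directly; "digit.png" is a
# placeholder path. Each SubPolicy applies its first op with probability p1
# and its second with probability p2, with the magnitude index selecting a
# strength from the 10-step ranges defined above.
if __name__ == "__main__":
    from PIL import Image

    policy = SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4)
    img = Image.open("digit.png").convert("RGB")  # placeholder input image
    augmented = policy(img)  # invert w.p. 0.6, then rotate ~13.3 deg w.p. 0.8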
@@ -0,0 +1,657 @@
import os
import math
import numpy as np

import torchvision
import torch.utils.data
import torchvision.transforms as transforms

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


class CIFAR10DataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):

        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.data) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'cifar10'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 10

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/CIFAR'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/CIFAR'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.train_path, _transforms)
        dataset = torchvision.datasets.CIFAR10(
            root=self.valid_path, train=True, download=False, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.valid_path, _transforms)
        dataset = torchvision.datasets.CIFAR10(
            root=self.valid_path, train=False, download=False, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        # return os.path.join(self.save_path, 'train')
        return self.save_path

    @property
    def valid_path(self):
        # return os.path.join(self.save_path, 'val')
        return self.save_path

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.49139968, 0.48215827, 0.44653124], std=[0.24703233, 0.24348505, 0.26158768])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.data)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]

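# Usage sketch (editor's addition): constructing the provider above. This
# assumes the OFA package is installed and the CIFAR-10 batches already sit
# under the (hypothetical) root below, since the datasets are created with
# download=False. A float valid_size is a fraction of the train split; with
# valid_size=None, provider.valid would alias provider.test instead.
if __name__ == "__main__":
    provider = CIFAR10DataProvider(
        save_path='./data/cifar10',  # hypothetical dataset root
        valid_size=0.1,              # hold out 10% of the train set
        image_size=224,
    )
    images, labels = next(iter(provider.train))
    print(images.shape, labels.shape)  # e.g. torch.Size([96, 3, 224, 224]), torch.Size([96])
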
class CIFAR100DataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):

        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.data) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'cifar100'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 100

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/CIFAR'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/CIFAR'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.train_path, _transforms)
        dataset = torchvision.datasets.CIFAR100(
            root=self.valid_path, train=True, download=False, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.valid_path, _transforms)
        dataset = torchvision.datasets.CIFAR100(
            root=self.valid_path, train=False, download=False, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        # return os.path.join(self.save_path, 'train')
        return self.save_path

    @property
    def valid_path(self):
        # return os.path.join(self.save_path, 'val')
        return self.save_path

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.49139968, 0.48215827, 0.44653124], std=[0.24703233, 0.24348505, 0.26158768])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.data)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]

class CINIC10DataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):

        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                # ImageFolder keeps its (path, target) pairs in .samples; the
                # original referenced .data here, which ImageFolder does not define.
                valid_size = int(len(train_dataset.samples) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'cinic10'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 10

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/CINIC10'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/CINIC10'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        dataset = torchvision.datasets.ImageFolder(self.train_path, transform=_transforms)
        # dataset = torchvision.datasets.CIFAR10(
        #     root=self.valid_path, train=True, download=False, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        dataset = torchvision.datasets.ImageFolder(self.valid_path, transform=_transforms)
        # dataset = torchvision.datasets.CIFAR10(
        #     root=self.valid_path, train=False, download=False, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train_and_valid')
        # return self.save_path

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'test')
        # return self.save_path

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.47889522, 0.47227842, 0.43047404], std=[0.24205776, 0.23828046, 0.25874835])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
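# Editor's note with a sketch: build_sub_train_loader() above caches a few
# training batches at the active resolution. In OFA-style supernets those
# batches are replayed to re-estimate BatchNorm running statistics after a
# sub-network is activated. A self-contained illustration of that pattern,
# with dummy batches standing in for the cached loader:
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
bn.reset_running_stats()
bn.train()
dummy_loader = [(torch.randn(8, 3, 32, 32), torch.zeros(8)) for _ in range(4)]
with torch.no_grad():
    for images, _ in dummy_loader:
        bn(images)  # each forward pass updates running_mean / running_var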
@@ -0,0 +1,237 @@
import os
import warnings
import numpy as np

from timm.data.transforms import _pil_interp
from timm.data.auto_augment import rand_augment_transform

import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


class DTDDataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):

        warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.samples) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'dtd'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 47

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/dtd'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/dtd'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.train_path, _transforms)
        return dataset

    def test_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.valid_path, _transforms)
        return dataset

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'valid')

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.5329876098715876, 0.474260843249454, 0.42627281899380676],
            std=[0.26549755708788914, 0.25473554309855373, 0.2631728035662832])

    def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
        if image_size is None:
            image_size = self.image_size
        # if print_log:
        #     print('Color jitter: %s, resize_scale: %s, img_size: %s' %
        #           (self.distort_color, self.resize_scale, image_size))

        # if self.distort_color == 'torch':
        #     color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        # elif self.distort_color == 'tf':
        #     color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        # else:
        #     color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
            img_size_min = min(image_size)
        else:
            resize_transform_class = transforms.RandomResizedCrop
            img_size_min = image_size

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]

        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in [0.5329876098715876, 0.474260843249454,
                                                               0.42627281899380676]]),
        )
        aa_params['interpolation'] = _pil_interp('bicubic')
        train_transforms += [rand_augment_transform(auto_augment, aa_params)]

        # if color_transform is not None:
        #     train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            # transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.Resize((image_size, image_size), interpolation=3),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
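# Editor's sketch of the in-place transform swap that assign_active_img_size()
# relies on: ImageFolder-style datasets apply self.transform at __getitem__
# time, so replacing it changes the resolution served by an already-built
# DataLoader. FakeData stands in for the DTD ImageFolder so the snippet runs
# without the dataset on disk.
import torchvision.transforms as transforms
import torchvision.datasets as datasets

ds = datasets.FakeData(size=4, image_size=(3, 256, 256),
                       transform=transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]))
print(ds[0][0].shape)  # torch.Size([3, 224, 224])
ds.transform = transforms.Compose([transforms.Resize((176, 176)), transforms.ToTensor()])
print(ds[0][0].shape)  # torch.Size([3, 176, 176])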
@@ -0,0 +1,241 @@
import warnings
import os
import math
import numpy as np

import PIL

import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


class Flowers102DataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=32, test_batch_size=512, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):

        # warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        weights = self.make_weights_for_balanced_classes(
            train_dataset.imgs, self.n_classes)
        weights = torch.DoubleTensor(weights)
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

        if valid_size is not None:
            raise NotImplementedError("validation dataset not yet implemented")
            # valid_dataset = self.valid_dataset(valid_transforms)

            # self.train = train_loader_class(
            #     train_dataset, batch_size=train_batch_size, sampler=train_sampler,
            #     num_workers=n_worker, pin_memory=True)
            # self.valid = torch.utils.data.DataLoader(
            #     valid_dataset, batch_size=test_batch_size,
            #     num_workers=n_worker, pin_memory=True)
        else:
            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        self.test = torch.utils.data.DataLoader(
            test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
        )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'flowers102'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 102

    @property
    def save_path(self):
        if self._save_path is None:
            # self._save_path = '/mnt/datastore/Oxford102Flowers'  # home server
            self._save_path = '/mnt/datastore/Flowers102'  # home server

        if not os.path.exists(self._save_path):
            # self._save_path = '/mnt/datastore/Oxford102Flowers'  # home server
            self._save_path = '/mnt/datastore/Flowers102'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.train_path, _transforms)
        return dataset

    # def valid_dataset(self, _transforms):
    #     dataset = datasets.ImageFolder(self.valid_path, _transforms)
    #     return dataset

    def test_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.test_path, _transforms)
        return dataset

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    # @property
    # def valid_path(self):
    #     return os.path.join(self.save_path, 'train')

    @property
    def test_path(self):
        return os.path.join(self.save_path, 'test')

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.5178361839861569, 0.4106749456881299, 0.32864167836880803],
            std=[0.2972239085211309, 0.24976049135203868, 0.28533308036347665])

    @staticmethod
    def make_weights_for_balanced_classes(images, nclasses):
        count = [0] * nclasses

        # Counts per label
        for item in images:
            count[item[1]] += 1

        weight_per_class = [0.] * nclasses

        # Total number of images.
        N = float(sum(count))

        # super-sample the smaller classes.
        for i in range(nclasses):
            weight_per_class[i] = N / float(count[i])

        weight = [0] * len(images)

        # Calculate a weight per image.
        for idx, val in enumerate(images):
            weight[idx] = weight_per_class[val[1]]

        return weight

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            transforms.RandomAffine(
                45, translate=(0.4, 0.4), scale=(0.75, 1.5), shear=None, resample=PIL.Image.BILINEAR, fillcolor=0),
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            # transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
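# Editor's worked example: make_weights_for_balanced_classes() gives every
# image the weight N / count(class), so a WeightedRandomSampler draws each
# class with equal total probability. Toy (path, label) pairs stand in for
# ImageFolder.imgs.
imgs = [('a.jpg', 0), ('b.jpg', 0), ('c.jpg', 0), ('d.jpg', 1)]  # 3-vs-1 imbalance
w = Flowers102DataProvider.make_weights_for_balanced_classes(imgs, 2)
print(w)  # ~[1.33, 1.33, 1.33, 4.0] -> both classes carry total weight 4.0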
@@ -0,0 +1,225 @@
|
||||
import warnings
|
||||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
import torch.utils.data
|
||||
import torchvision.transforms as transforms
|
||||
import torchvision.datasets as datasets
|
||||
|
||||
from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler
|
||||
|
||||
|
||||
class ImagenetDataProvider(DataProvider):
|
||||
|
||||
def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, n_worker=32,
|
||||
resize_scale=0.08, distort_color=None, image_size=224,
|
||||
num_replicas=None, rank=None):
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
self._save_path = save_path
|
||||
|
||||
self.image_size = image_size # int or list of int
|
||||
self.distort_color = distort_color
|
||||
self.resize_scale = resize_scale
|
||||
|
||||
self._valid_transform_dict = {}
|
||||
if not isinstance(self.image_size, int):
|
||||
assert isinstance(self.image_size, list)
|
||||
from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
|
||||
self.image_size.sort() # e.g., 160 -> 224
|
||||
MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
|
||||
MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
|
||||
|
||||
for img_size in self.image_size:
|
||||
self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
|
||||
self.active_img_size = max(self.image_size)
|
||||
valid_transforms = self._valid_transform_dict[self.active_img_size]
|
||||
train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
|
||||
else:
|
||||
self.active_img_size = self.image_size
|
||||
valid_transforms = self.build_valid_transform()
|
||||
train_loader_class = torch.utils.data.DataLoader
|
||||
|
||||
train_transforms = self.build_train_transform()
|
||||
train_dataset = self.train_dataset(train_transforms)
|
||||
|
||||
if valid_size is not None:
|
||||
if not isinstance(valid_size, int):
|
||||
assert isinstance(valid_size, float) and 0 < valid_size < 1
|
||||
valid_size = int(len(train_dataset.samples) * valid_size)
|
||||
|
||||
valid_dataset = self.train_dataset(valid_transforms)
|
||||
train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
|
||||
|
||||
if num_replicas is not None:
|
||||
train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
|
||||
valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
|
||||
else:
|
||||
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
|
||||
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
|
||||
|
||||
self.train = train_loader_class(
|
||||
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
|
||||
num_workers=n_worker, pin_memory=True,
|
||||
)
|
||||
self.valid = torch.utils.data.DataLoader(
|
||||
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
|
||||
num_workers=n_worker, pin_memory=True,
|
||||
)
|
||||
else:
|
||||
if num_replicas is not None:
|
||||
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
|
||||
self.train = train_loader_class(
|
||||
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
|
||||
num_workers=n_worker, pin_memory=True
|
||||
)
|
||||
else:
|
||||
self.train = train_loader_class(
|
||||
train_dataset, batch_size=train_batch_size, shuffle=True,
|
||||
num_workers=n_worker, pin_memory=True,
|
||||
)
|
||||
self.valid = None
|
||||
|
||||
test_dataset = self.test_dataset(valid_transforms)
|
||||
if num_replicas is not None:
|
||||
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
|
||||
self.test = torch.utils.data.DataLoader(
|
||||
test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
|
||||
)
|
||||
else:
|
||||
self.test = torch.utils.data.DataLoader(
|
||||
test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
|
||||
)
|
||||
|
||||
if self.valid is None:
|
||||
self.valid = self.test
|
||||
|
||||
@staticmethod
|
||||
def name():
|
||||
return 'imagenet'
|
||||
|
||||
@property
|
||||
def data_shape(self):
|
||||
return 3, self.active_img_size, self.active_img_size # C, H, W
|
||||
|
||||
@property
|
||||
def n_classes(self):
|
||||
return 1000
|
||||
|
||||
@property
|
||||
def save_path(self):
|
||||
if self._save_path is None:
|
||||
# self._save_path = '/dataset/imagenet'
|
||||
# self._save_path = '/usr/local/soft/temp-datastore/ILSVRC2012' # servers
|
||||
self._save_path = '/mnt/datastore/ILSVRC2012' # home server
|
||||
|
||||
if not os.path.exists(self._save_path):
|
||||
# self._save_path = os.path.expanduser('~/dataset/imagenet')
|
||||
# self._save_path = os.path.expanduser('/usr/local/soft/temp-datastore/ILSVRC2012')
|
||||
self._save_path = '/mnt/datastore/ILSVRC2012' # home server
|
||||
return self._save_path
|
||||
|
||||
@property
|
||||
def data_url(self):
|
||||
raise ValueError('unable to download %s' % self.name())
|
||||
|
||||
def train_dataset(self, _transforms):
|
||||
dataset = datasets.ImageFolder(self.train_path, _transforms)
|
||||
return dataset
|
||||
|
||||
def test_dataset(self, _transforms):
|
||||
dataset = datasets.ImageFolder(self.valid_path, _transforms)
|
||||
return dataset
|
||||
|
||||
@property
|
||||
def train_path(self):
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'val')

    @property
    def normalize(self):
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
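# Illustrative check (not part of the commit): build_valid_transform above resizes
# to ceil(image_size / 0.875) before center-cropping, the usual 87.5% crop ratio.
import math
for s in [160, 192, 224]:
    print(s, '->', int(math.ceil(s / 0.875)))  # 160 -> 183, 192 -> 220, 224 -> 256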
@@ -0,0 +1,237 @@
import os
import math
import warnings
import numpy as np

# from timm.data.transforms import _pil_interp
from timm.data.auto_augment import rand_augment_transform

import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


class OxfordIIITPetsDataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):

        warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.samples) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'pets'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 37

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/Oxford-IIITPets'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/Oxford-IIITPets'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.train_path, _transforms)
        return dataset

    def test_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.valid_path, _transforms)
        return dataset

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'valid')

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.4828895122298728, 0.4448394893850807, 0.39566558230789783],
            std=[0.25925664613996574, 0.2532760018681693, 0.25981017205097917])

    def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
        if image_size is None:
            image_size = self.image_size
        # if print_log:
        #     print('Color jitter: %s, resize_scale: %s, img_size: %s' %
        #           (self.distort_color, self.resize_scale, image_size))

        # if self.distort_color == 'torch':
        #     color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        # elif self.distort_color == 'tf':
        #     color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        # else:
        #     color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
            img_size_min = min(image_size)
        else:
            resize_transform_class = transforms.RandomResizedCrop
            img_size_min = image_size

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]

        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in [0.4828895122298728, 0.4448394893850807,
                                                               0.39566558230789783]]),
        )
        aa_params['interpolation'] = transforms.Resize(image_size)  # _pil_interp('bicubic')
        train_transforms += [rand_augment_transform(auto_augment, aa_params)]

        # if color_transform is not None:
        #     train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
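# Illustrative check (not part of the commit): the img_mean handed to
# rand_augment_transform above is just the per-channel dataset mean in 0-255 space.
mean = [0.4828895122298728, 0.4448394893850807, 0.39566558230789783]
print(tuple(min(255, round(255 * x)) for x in mean))  # (123, 113, 101)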
@@ -0,0 +1,69 @@
import os
from glob import glob

import torch
from torch.utils.data.dataset import Dataset
from PIL import Image


def load_image(filename):
    img = Image.open(filename)
    img = img.convert('RGB')
    return img


class PetDataset(Dataset):
    def __init__(self, root, train=True, num_cl=37, val_split=0.15, transforms=None):
        pt_name = os.path.join(root, '{}{}.pth'.format('train' if train else 'test',
                                                       int(100 * (1 - val_split)) if train else int(
                                                           100 * val_split)))
        if not os.path.exists(pt_name):
            filenames = glob(os.path.join(root, 'images') + '/*.jpg')
            classes = set()

            data = []
            labels = []

            for image in filenames:
                class_name = image.rsplit("/", 1)[1].rsplit('_', 1)[0]
                classes.add(class_name)
                img = load_image(image)

                data.append(img)
                labels.append(class_name)

            # convert class names to indices
            # (note: iteration order of a set is arbitrary, so the class-to-index
            # mapping is only stable once the split has been cached to disk)
            class2idx = {cl: idx for idx, cl in enumerate(classes)}
            labels = torch.Tensor(list(map(lambda x: class2idx[x], labels))).long()
            data = list(zip(data, labels))

            class_values = [[] for x in range(num_cl)]

            # create arrays for each class type
            for d in data:
                class_values[d[1].item()].append(d)

            train_data = []
            val_data = []

            for class_dp in class_values:
                split_idx = int(len(class_dp) * (1 - val_split))
                train_data += class_dp[:split_idx]
                val_data += class_dp[split_idx:]
            torch.save(train_data, os.path.join(root, 'train{}.pth'.format(int(100 * (1 - val_split)))))
            torch.save(val_data, os.path.join(root, 'test{}.pth'.format(int(100 * val_split))))

        self.data = torch.load(pt_name)
        self.len = len(self.data)
        self.transform = transforms

    def __getitem__(self, index):
        img, label = self.data[index]

        if self.transform:
            img = self.transform(img)

        return img, label

    def __len__(self):
        return self.len
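# Usage sketch (not part of the commit; the data root below is an assumption): the
# first PetDataset instantiation scans <root>/images/*.jpg, builds class-balanced
# splits, and caches them as train85.pth / test15.pth; later runs load the cache.
from torchvision import transforms
root = '/path/to/Oxford-IIITPets'  # hypothetical location of the extracted dataset
train_ds = PetDataset(root, train=True, val_split=0.15, transforms=transforms.ToTensor())
img, label = train_ds[0]
print(len(train_ds), img.shape, int(label))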
@@ -0,0 +1,226 @@
import os
import math
import numpy as np

import torchvision
import torch.utils.data
import torchvision.transforms as transforms

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


class STL10DataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):

        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.data) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'stl10'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 10

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/STL10'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/STL10'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.train_path, _transforms)
        dataset = torchvision.datasets.STL10(
            root=self.valid_path, split='train', download=False, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.valid_path, _transforms)
        dataset = torchvision.datasets.STL10(
            root=self.valid_path, split='test', download=False, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        # return os.path.join(self.save_path, 'train')
        return self.save_path

    @property
    def valid_path(self):
        # return os.path.join(self.save_path, 'val')
        return self.save_path

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.44671097, 0.4398105, 0.4066468],
            std=[0.2603405, 0.25657743, 0.27126738])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.data)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
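# Note (illustrative, not part of the commit): unlike the ImageFolder-based
# providers, this one loads STL10 through torchvision itself; train_path and
# valid_path both resolve to save_path, so one root directory serves both splits.
# A minimal sketch, assuming the dataset has already been downloaded to that root:
# ds = torchvision.datasets.STL10(root='/mnt/datastore/STL10', split='train', download=False)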
@@ -0,0 +1,4 @@
from ofa.imagenet_codebase.networks.proxyless_nets import ProxylessNASNets, proxyless_base, MobileNetV2
from ofa.imagenet_codebase.networks.mobilenet_v3 import MobileNetV3, MobileNetV3Large
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks.nsganetv2 import NSGANetV2
@@ -0,0 +1,126 @@
from timm.models.layers import drop_path

from ofa.imagenet_codebase.modules.layers import *
from ofa.imagenet_codebase.networks import MobileNetV3


class MobileInvertedResidualBlock(MyModule):
    """
    Modified from https://github.com/mit-han-lab/once-for-all/blob/master/ofa/
    imagenet_codebase/networks/proxyless_nets.py to include drop path in training
    """

    def __init__(self, mobile_inverted_conv, shortcut, drop_connect_rate=0.0):
        super(MobileInvertedResidualBlock, self).__init__()

        self.mobile_inverted_conv = mobile_inverted_conv
        self.shortcut = shortcut
        self.drop_connect_rate = drop_connect_rate

    def forward(self, x):
        if self.mobile_inverted_conv is None or isinstance(self.mobile_inverted_conv, ZeroLayer):
            res = x
        elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer):
            res = self.mobile_inverted_conv(x)
        else:
            # res = self.mobile_inverted_conv(x) + self.shortcut(x)
            res = self.mobile_inverted_conv(x)

            if self.drop_connect_rate > 0.:
                res = drop_path(res, drop_prob=self.drop_connect_rate, training=self.training)

            res += self.shortcut(x)

        return res

    @property
    def module_str(self):
        return '(%s, %s)' % (
            self.mobile_inverted_conv.module_str if self.mobile_inverted_conv is not None else None,
            self.shortcut.module_str if self.shortcut is not None else None
        )

    @property
    def config(self):
        return {
            'name': MobileInvertedResidualBlock.__name__,
            'mobile_inverted_conv': self.mobile_inverted_conv.config if self.mobile_inverted_conv is not None else None,
            'shortcut': self.shortcut.config if self.shortcut is not None else None,
        }

    @staticmethod
    def build_from_config(config):
        mobile_inverted_conv = set_layer_from_config(config['mobile_inverted_conv'])
        shortcut = set_layer_from_config(config['shortcut'])
        return MobileInvertedResidualBlock(
            mobile_inverted_conv, shortcut, drop_connect_rate=config['drop_connect_rate'])


class NSGANetV2(MobileNetV3):
    """
    Modified from https://github.com/mit-han-lab/once-for-all/blob/master/ofa/
    imagenet_codebase/networks/mobilenet_v3.py to include drop path in training
    and option to reset classification layer
    """

    @staticmethod
    def build_from_config(config, drop_connect_rate=0.0):
        first_conv = set_layer_from_config(config['first_conv'])
        final_expand_layer = set_layer_from_config(config['final_expand_layer'])
        feature_mix_layer = set_layer_from_config(config['feature_mix_layer'])
        classifier = set_layer_from_config(config['classifier'])

        blocks = []
        for block_idx, block_config in enumerate(config['blocks']):
            block_config['drop_connect_rate'] = drop_connect_rate * block_idx / len(config['blocks'])
            blocks.append(MobileInvertedResidualBlock.build_from_config(block_config))

        net = MobileNetV3(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
        if 'bn' in config:
            net.set_bn_param(**config['bn'])
        else:
            net.set_bn_param(momentum=0.1, eps=1e-3)

        return net

    def zero_last_gamma(self):
        for m in self.modules():
            if isinstance(m, MobileInvertedResidualBlock):
                if isinstance(m.mobile_inverted_conv, MBInvertedConvLayer) and isinstance(m.shortcut, IdentityLayer):
                    m.mobile_inverted_conv.point_linear.bn.weight.data.zero_()

    @staticmethod
    def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate):
        # first conv layer
        first_conv = ConvLayer(
            3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='h_swish', ops_order='weight_bn_act'
        )
        # build mobile blocks
        feature_dim = input_channel
        blocks = []
        for stage_id, block_config_list in cfg.items():
            for k, mid_channel, out_channel, use_se, act_func, stride, expand_ratio in block_config_list:
                mb_conv = MBInvertedConvLayer(
                    feature_dim, out_channel, k, stride, expand_ratio, mid_channel, act_func, use_se
                )
                if stride == 1 and out_channel == feature_dim:
                    shortcut = IdentityLayer(out_channel, out_channel)
                else:
                    shortcut = None
                blocks.append(MobileInvertedResidualBlock(mb_conv, shortcut))
                feature_dim = out_channel
        # final expand layer
        final_expand_layer = ConvLayer(
            feature_dim, feature_dim * 6, kernel_size=1, use_bn=True, act_func='h_swish', ops_order='weight_bn_act',
        )
        feature_dim = feature_dim * 6
        # feature mix layer
        feature_mix_layer = ConvLayer(
            feature_dim, last_channel, kernel_size=1, bias=False, use_bn=False, act_func='h_swish',
        )
        # classifier
        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier

    @staticmethod
    def reset_classifier(model, last_channel, n_classes, dropout_rate=0.0):
        model.classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
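# Illustrative check (not part of the commit): build_from_config above scales the
# drop path rate linearly with block index, from 0 at the first block to just
# under the configured rate at the last.
drop_connect_rate, n_blocks = 0.2, 20
print([round(drop_connect_rate * i / n_blocks, 3) for i in range(0, n_blocks, 5)])
# [0.0, 0.05, 0.1, 0.15]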
@@ -0,0 +1,309 @@
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.imagenet import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.cifar import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.pets import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.aircraft import *

from ofa.imagenet_codebase.run_manager.run_manager import *

# NOTE: Flowers102DataProvider, STL10DataProvider and DTDDataProvider are referenced
# below, but their provider modules are not imported here in the original source.


class ImagenetRunConfig(RunConfig):

    def __init__(self, n_epochs=1, init_lr=1e-4, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='imagenet', train_batch_size=128, test_batch_size=512, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
                 data_path='/mnt/datastore/ILSVRC2012',
                 **kwargs):
        super(ImagenetRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.imagenet_data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == ImagenetDataProvider.name():
                DataProviderClass = ImagenetDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.imagenet_data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class CIFARRunConfig(RunConfig):
    def __init__(self, n_epochs=5, init_lr=0.01, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='cifar10', train_batch_size=96, test_batch_size=256, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224,
                 data_path='/mnt/datastore/CIFAR',
                 **kwargs):
        super(CIFARRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )

        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.cifar_data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == CIFAR10DataProvider.name():
                DataProviderClass = CIFAR10DataProvider
            elif self.dataset == CIFAR100DataProvider.name():
                DataProviderClass = CIFAR100DataProvider
            elif self.dataset == CINIC10DataProvider.name():
                DataProviderClass = CINIC10DataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.cifar_data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class Flowers102RunConfig(RunConfig):

    def __init__(self, n_epochs=3, init_lr=1e-2, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='flowers102', train_batch_size=32, test_batch_size=250, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=4, resize_scale=0.08, distort_color=None, image_size=224,
                 data_path='/mnt/datastore/Flowers102',
                 **kwargs):
        super(Flowers102RunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )

        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.flowers102_data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == Flowers102DataProvider.name():
                DataProviderClass = Flowers102DataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.flowers102_data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class STL10RunConfig(RunConfig):

    def __init__(self, n_epochs=5, init_lr=1e-2, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='stl10', train_batch_size=96, test_batch_size=256, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=4, resize_scale=0.08, distort_color=None, image_size=224,
                 data_path='/mnt/datastore/STL10',
                 **kwargs):
        super(STL10RunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )

        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.stl10_data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == STL10DataProvider.name():
                DataProviderClass = STL10DataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.stl10_data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class DTDRunConfig(RunConfig):

    def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='dtd', train_batch_size=32, test_batch_size=250, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
                 data_path='/mnt/datastore/dtd',
                 **kwargs):
        super(DTDRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == DTDDataProvider.name():
                DataProviderClass = DTDDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class PetsRunConfig(RunConfig):

    def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='pets', train_batch_size=32, test_batch_size=250, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
                 data_path='/mnt/datastore/Oxford-IIITPets',
                 **kwargs):
        super(PetsRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.imagenet_data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == OxfordIIITPetsDataProvider.name():
                DataProviderClass = OxfordIIITPetsDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.imagenet_data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class AircraftRunConfig(RunConfig):

    def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='aircraft', train_batch_size=32, test_batch_size=250, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
                 data_path='/mnt/datastore/Aircraft',
                 **kwargs):
        super(AircraftRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == FGVCAircraftDataProvider.name():
                DataProviderClass = FGVCAircraftDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


def get_run_config(**kwargs):
    if kwargs['dataset'] == 'imagenet':
        run_config = ImagenetRunConfig(**kwargs)
    elif kwargs['dataset'].startswith('cifar') or kwargs['dataset'].startswith('cinic'):
        run_config = CIFARRunConfig(**kwargs)
    elif kwargs['dataset'] == 'flowers102':
        run_config = Flowers102RunConfig(**kwargs)
    elif kwargs['dataset'] == 'stl10':
        run_config = STL10RunConfig(**kwargs)
    elif kwargs['dataset'] == 'dtd':
        run_config = DTDRunConfig(**kwargs)
    elif kwargs['dataset'] == 'pets':
        run_config = PetsRunConfig(**kwargs)
    elif kwargs['dataset'] in ('aircraft', 'aircraft100'):
        run_config = AircraftRunConfig(**kwargs)
    else:
        raise NotImplementedError

    return run_config
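# Usage sketch (not part of the commit): the dataset keyword picks the RunConfig
# subclass, and each class forwards unrecognized keywords through **kwargs.
# cfg = get_run_config(dataset='cifar10', n_epochs=5, init_lr=0.01)
# loader = cfg.data_provider.train  # the provider is built lazily on first access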
@@ -0,0 +1,122 @@
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import torchvision.utils
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.aircraft import FGVCAircraft
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.pets2 import PetDataset
import torch.utils.data as Data
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.autoaugment import CIFAR10Policy


def get_dataset(data_name, batch_size, data_path, num_workers,
                img_size, autoaugment, cutout, cutout_length):
    num_class_dict = {
        'cifar100': 100,
        'cifar10': 10,
        'mnist': 10,
        'aircraft': 100,
        'svhn': 10,
        'pets': 37,
    }
    # 'aircraft30': 30,
    # 'aircraft100': 100,

    train_transform, valid_transform = _data_transforms(
        data_name, img_size, autoaugment, cutout, cutout_length)
    if data_name == 'cifar100':
        train_data = torchvision.datasets.CIFAR100(
            root=data_path, train=True, download=True, transform=train_transform)
        valid_data = torchvision.datasets.CIFAR100(
            root=data_path, train=False, download=True, transform=valid_transform)
    elif data_name == 'cifar10':
        train_data = torchvision.datasets.CIFAR10(
            root=data_path, train=True, download=True, transform=train_transform)
        valid_data = torchvision.datasets.CIFAR10(
            root=data_path, train=False, download=True, transform=valid_transform)
    elif data_name.startswith('aircraft'):
        print(data_path)
        if 'aircraft100' in data_path:
            data_path = data_path.replace('aircraft100', 'aircraft/fgvc-aircraft-2013b')
        else:
            data_path = data_path.replace('aircraft', 'aircraft/fgvc-aircraft-2013b')
        train_data = FGVCAircraft(data_path, class_type='variant', split='trainval',
                                  transform=train_transform, download=True)
        valid_data = FGVCAircraft(data_path, class_type='variant', split='test',
                                  transform=valid_transform, download=True)
    elif data_name.startswith('pets'):
        train_data = PetDataset(data_path, train=True, num_cl=37,
                                val_split=0.15, transforms=train_transform)
        valid_data = PetDataset(data_path, train=False, num_cl=37,
                                val_split=0.15, transforms=valid_transform)
    else:
        raise KeyError

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, pin_memory=True,
        num_workers=num_workers)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=200, shuffle=False, pin_memory=True,
        num_workers=num_workers)

    return train_queue, valid_queue, num_class_dict[data_name]


class Cutout(object):
    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y = np.random.randint(h)
        x = np.random.randint(w)

        y1 = np.clip(y - self.length // 2, 0, h)
        y2 = np.clip(y + self.length // 2, 0, h)
        x1 = np.clip(x - self.length // 2, 0, w)
        x2 = np.clip(x + self.length // 2, 0, w)

        mask[y1: y2, x1: x2] = 0.
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img *= mask
        return img


def _data_transforms(data_name, img_size, autoaugment, cutout, cutout_length):
    if 'cifar' in data_name:
        norm_mean = [0.49139968, 0.48215827, 0.44653124]
        norm_std = [0.24703233, 0.24348505, 0.26158768]
    elif 'aircraft' in data_name:
        norm_mean = [0.48933587508932375, 0.5183537408957618, 0.5387914411673883]
        norm_std = [0.22388883112804625, 0.21641635409388751, 0.24615605842636115]
    elif 'pets' in data_name:
        norm_mean = [0.4828895122298728, 0.4448394893850807, 0.39566558230789783]
        norm_std = [0.25925664613996574, 0.2532760018681693, 0.25981017205097917]
    else:
        raise KeyError

    train_transform = transforms.Compose([
        transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC),  # BICUBIC interpolation
        transforms.RandomHorizontalFlip(),
    ])

    if autoaugment:
        train_transform.transforms.append(CIFAR10Policy())

    train_transform.transforms.append(transforms.ToTensor())

    if cutout:
        train_transform.transforms.append(Cutout(cutout_length))

    train_transform.transforms.append(transforms.Normalize(norm_mean, norm_std))

    valid_transform = transforms.Compose([
        transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC),  # BICUBIC interpolation
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])
    return train_transform, valid_transform
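# Minimal check (not part of the commit), reusing the Cutout class defined above:
# at most length x length pixels per channel are zeroed, clipped at the borders.
import torch
img = torch.ones(3, 32, 32)
out = Cutout(length=8)(img)
print(int((out == 0).sum()) // 3 <= 8 * 8)  # True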
@@ -0,0 +1,233 @@
import os
import sys
import copy
import json
import random
import warnings

import numpy as np
import torch

import transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.eval_utils as eval_utils
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks import NSGANetV2
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.run_manager import get_run_config
from ofa.elastic_nn.networks import OFAMobileNetV3
from ofa.imagenet_codebase.run_manager import RunManager
from ofa.elastic_nn.modules.dynamic_op import DynamicSeparableConv2d
from torchprofile import profile_macs

warnings.simplefilter("ignore")

DynamicSeparableConv2d.KERNEL_TRANSFORM_MODE = 1


class ArchManager:
    def __init__(self):
        self.num_blocks = 20
        self.num_stages = 5
        self.kernel_sizes = [3, 5, 7]
        self.expand_ratios = [3, 4, 6]
        self.depths = [2, 3, 4]
        self.resolutions = [160, 176, 192, 208, 224]

    def random_sample(self):
        d = []
        e = []
        ks = []
        for i in range(self.num_stages):
            d.append(random.choice(self.depths))

        for i in range(self.num_blocks):
            e.append(random.choice(self.expand_ratios))
            ks.append(random.choice(self.kernel_sizes))

        sample = {
            'wid': None,
            'ks': ks,
            'e': e,
            'd': d,
            'r': [random.choice(self.resolutions)]
        }

        return sample

    def random_resample(self, sample, i):
        assert i >= 0 and i < self.num_blocks
        sample['ks'][i] = random.choice(self.kernel_sizes)
        sample['e'][i] = random.choice(self.expand_ratios)

    def random_resample_depth(self, sample, i):
        assert i >= 0 and i < self.num_stages
        sample['d'][i] = random.choice(self.depths)

    def random_resample_resolution(self, sample):
        sample['r'][0] = random.choice(self.resolutions)


def parse_string_list(string):
    if isinstance(string, str):
        # convert '[5 5 5 7 7 7 3 3 7 7 7 3 3]' to [5, 5, 5, 7, 7, 7, 3, 3, 7, 7, 7, 3, 3]
        return list(map(int, string[1:-1].split()))
    else:
        return string


def pad_none(x, depth, max_depth):
    new_x, counter = [], 0
    for d in depth:
        for _ in range(d):
            new_x.append(x[counter])
            counter += 1
        if d < max_depth:
            new_x += [None] * (max_depth - d)
    return new_x


def get_net_info(net, data_shape, measure_latency=None, print_info=True, clean=False, lut=None):
    net_info = eval_utils.get_net_info(
        net, data_shape, measure_latency, print_info=print_info, clean=clean, lut=lut)

    gpu_latency, cpu_latency = None, None
    for k in net_info.keys():
        if 'gpu' in k:
            gpu_latency = np.round(net_info[k]['val'], 2)
        if 'cpu' in k:
            cpu_latency = np.round(net_info[k]['val'], 2)

    return {
        'params': np.round(net_info['params'] / 1e6, 2),
        'flops': np.round(net_info['flops'] / 1e6, 2),
        'gpu': gpu_latency, 'cpu': cpu_latency
    }


def validate_config(config, max_depth=4):
    kernel_size, exp_ratio, depth = config['ks'], config['e'], config['d']

    if isinstance(kernel_size, str):
        kernel_size = parse_string_list(kernel_size)
    if isinstance(exp_ratio, str):
        exp_ratio = parse_string_list(exp_ratio)
    if isinstance(depth, str):
        depth = parse_string_list(depth)

    assert (isinstance(kernel_size, list) or isinstance(kernel_size, int))
    assert (isinstance(exp_ratio, list) or isinstance(exp_ratio, int))
    assert isinstance(depth, list)

    if len(kernel_size) < len(depth) * max_depth:
        kernel_size = pad_none(kernel_size, depth, max_depth)
    if len(exp_ratio) < len(depth) * max_depth:
        exp_ratio = pad_none(exp_ratio, depth, max_depth)

    # return {'ks': kernel_size, 'e': exp_ratio, 'd': depth, 'w': config['w']}
    return {'ks': kernel_size, 'e': exp_ratio, 'd': depth}


def set_nas_test_dataset(path, test_data_name, max_img):
    if test_data_name not in ['mnist', 'svhn', 'cifar10',
                              'cifar100', 'aircraft', 'pets']:
        raise ValueError(test_data_name)

    dpath = path
    num_cls = 10  # mnist, svhn, cifar10
    if test_data_name in ['cifar100', 'aircraft']:
        num_cls = 100
    elif test_data_name == 'pets':
        num_cls = 37

    x = torch.load(dpath + f'/{test_data_name}bylabel')
    img_per_cls = min(int(max_img / num_cls), 20)
    return x, img_per_cls, num_cls


class OFAEvaluator:
    """ based on OnceForAll supernet taken from https://github.com/mit-han-lab/once-for-all """

    def __init__(self, num_gen_arch, img_size, drop_path,
                 n_classes=1000,
                 model_path=None,
                 kernel_size=None, exp_ratio=None, depth=None):
        # default configurations
        self.kernel_size = [3, 5, 7] if kernel_size is None else kernel_size  # depth-wise conv kernel size
        self.exp_ratio = [3, 4, 6] if exp_ratio is None else exp_ratio  # expansion rate
        self.depth = [2, 3, 4] if depth is None else depth  # number of MB block repetitions

        if 'w1.0' in model_path:
            self.width_mult = 1.0
        elif 'w1.2' in model_path:
            self.width_mult = 1.2
        else:
            raise ValueError

        self.engine = OFAMobileNetV3(
            n_classes=n_classes,
            dropout_rate=0, width_mult_list=self.width_mult, ks_list=self.kernel_size,
            expand_ratio_list=self.exp_ratio, depth_list=self.depth)

        init = torch.load(model_path, map_location='cpu')['state_dict']
        self.engine.load_weights_from_net(init)
        print(f'load {model_path}...')

        # metad2a
        self.arch_manager = ArchManager()
        self.num_gen_arch = num_gen_arch
        # kept for get_architecture() below, which rebuilds candidate subnets
        self.img_size = img_size
        self.drop_path = drop_path

    def sample_random_architecture(self):
        sampled_architecture = self.arch_manager.random_sample()
        return sampled_architecture

    def get_architecture(self, bound=None):
        # NOTE: self.lines (ranked architecture records, one per line) is expected
        # to be populated elsewhere before this method is called.
        g_lst, pred_acc_lst, x_lst = [], [], []
        searched_g, max_pred_acc = None, 0

        with torch.no_grad():
            for n in range(self.num_gen_arch):
                file_acc = self.lines[n].split()[0]
                g_dict = ' '.join(self.lines[n].split())
                g = json.loads(g_dict.replace("'", "\""))

                if bound is not None:
                    subnet, config = self.sample(config=g)
                    net = NSGANetV2.build_from_config(subnet.config,
                                                      drop_connect_rate=self.drop_path)
                    inputs = torch.randn(1, 3, self.img_size, self.img_size)
                    flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
                    if flops <= bound:
                        searched_g = g
                        break
                else:
                    searched_g = g
                    pred_acc_lst.append(file_acc)
                    break

        if searched_g is None:
            raise ValueError(searched_g)
        return searched_g, pred_acc_lst

    def sample(self, config=None):
        """ randomly sample a sub-network """
        if config is not None:
            config = validate_config(config)
            self.engine.set_active_subnet(ks=config['ks'], e=config['e'], d=config['d'])
        else:
            config = self.engine.sample_active_subnet()

        subnet = self.engine.get_active_subnet(preserve_weight=True)
        return subnet, config

    @staticmethod
    def save_net_config(path, net, config_name='net.config'):
        """ dump run_config and net_config to the model_folder """
        net_save_path = os.path.join(path, config_name)
        json.dump(net.config, open(net_save_path, 'w'), indent=4)
        print('Network configs dump to %s' % net_save_path)

    @staticmethod
    def save_net(path, net, model_name):
        """ dump net weight as checkpoint """
        if isinstance(net, torch.nn.DataParallel):
            checkpoint = {'state_dict': net.module.state_dict()}
        else:
            checkpoint = {'state_dict': net.state_dict()}
        model_path = os.path.join(path, model_name)
        torch.save(checkpoint, model_path)
        print('Network model dump to %s' % model_path)
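# Illustrative check (not part of the commit): validate_config pads the per-block
# kernel/expand lists with None so each stage holds max_depth entries.
cfg = {'ks': '[3 5 7 3 5]', 'e': '[3 4 6 3 4]', 'd': '[2 3]'}
print(validate_config(cfg, max_depth=4)['ks'])
# [3, 5, None, None, 7, 3, 5, None]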
@@ -0,0 +1,169 @@
import os
import sys
import json
import logging
import numpy as np
import copy
import torch
import torch.nn as nn
import random
import torch.optim as optim
from evaluator import OFAEvaluator
from torchprofile import profile_macs
from codebase.networks import NSGANetV2
from parser import get_parse
from eval_utils import get_dataset


args = get_parse()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device_list = [int(_) for _ in args.gpu.split(',')]
args.n_gpus = len(device_list)
args.device = torch.device("cuda:0")

if args.seed is None or args.seed < 0: args.seed = random.randint(1, 100000)
torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)


evaluator = OFAEvaluator(args,
                         model_path='../.torch/ofa_nets/ofa_mbv3_d234_e346_k357_w1.0')

args.save_path = os.path.join(args.save_path, f'evaluation/{args.data_name}')
if args.model_config.startswith('flops@'):
    args.save_path += f'-nsganetV2-{args.model_config}-{args.seed}'
else:
    args.save_path += f'-metaD2A-{args.bound}-{args.seed}'
if not os.path.exists(args.save_path):
    os.makedirs(args.save_path)

args.data_path = os.path.join(args.data_path, args.data_name)

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save_path, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
if not torch.cuda.is_available():
    logging.info('no gpu device available')
    sys.exit(1)
logging.info("args = %s", args)


def set_architecture(n_cls):
    if args.model_config.startswith('flops@'):
        names = {'cifar10': 'CIFAR-10', 'cifar100': 'CIFAR-100',
                 'aircraft100': 'Aircraft', 'pets': 'Pets'}
        p = os.path.join('./searched-architectures/{}/net-{}/net.subnet'.
                         format(names[args.data_name], args.model_config))
        g = json.load(open(p))
    else:
        g, acc = evaluator.get_architecture(args)

    subnet, config = evaluator.sample(g)
    net = NSGANetV2.build_from_config(subnet.config, drop_connect_rate=args.drop_path)
    net.load_state_dict(subnet.state_dict())

    NSGANetV2.reset_classifier(
        net, last_channel=net.classifier.in_features,
        n_classes=n_cls, dropout_rate=args.drop)
    # calculate #Parameters and #FLOPs
    inputs = torch.randn(1, 3, args.img_size, args.img_size)
    flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
    params = sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6
    net_name = "net_flops@{:.0f}".format(flops)
    logging.info('#params {:.2f}M, #flops {:.0f}M'.format(params, flops))
    OFAEvaluator.save_net_config(args.save_path, net, net_name + '.config')
    if args.n_gpus > 1:
        net = nn.DataParallel(net)  # data parallel in case more than 1 gpu available
    net = net.to(args.device)

    return net, net_name


def train(train_queue, net, criterion, optimizer):
    net.train()
    train_loss, correct, total = 0, 0, 0
    for step, (inputs, targets) in enumerate(train_queue):
        # upsample by bicubic to match imagenet training size
        inputs, targets = inputs.to(args.device), targets.to(args.device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if step % args.report_freq == 0:
            logging.info('train %03d %e %f', step, train_loss / total, 100. * correct / total)
    logging.info('train acc %f', 100. * correct / total)
    return train_loss / total, 100. * correct / total


def infer(valid_queue, net, criterion, early_stop=False):
    net.eval()
    test_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for step, (inputs, targets) in enumerate(valid_queue):
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f', step, test_loss / total, 100. * correct / total)
            if early_stop and step == 10:
                break
    acc = 100. * correct / total
    logging.info('valid acc %f', 100. * correct / total)

    return test_loss / total, acc


def main():
    best_acc, top_checkpoints = 0, []

    train_queue, valid_queue, n_cls = get_dataset(args)
    net, net_name = set_architecture(n_cls)
    parameters = filter(lambda p: p.requires_grad, net.parameters())
    optimizer = optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss().to(args.device)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    for epoch in range(args.epochs):
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])

        train(train_queue, net, criterion, optimizer)
        _, valid_acc = infer(valid_queue, net, criterion)
        # checkpoint saving
        if len(top_checkpoints) < args.topk:
            OFAEvaluator.save_net(args.save_path, net, net_name + '.ckpt{}'.format(epoch))
            top_checkpoints.append((os.path.join(args.save_path, net_name + '.ckpt{}'.format(epoch)), valid_acc))
        else:
            idx = np.argmin([x[1] for x in top_checkpoints])
            if valid_acc > top_checkpoints[idx][1]:
                OFAEvaluator.save_net(args.save_path, net, net_name + '.ckpt{}'.format(epoch))
                top_checkpoints.append((os.path.join(args.save_path, net_name + '.ckpt{}'.format(epoch)), valid_acc))
                # remove the replaced (worst) checkpoint
                os.remove(top_checkpoints[idx][0])
                top_checkpoints.pop(idx)
                print(top_checkpoints)
        if valid_acc > best_acc:
            OFAEvaluator.save_net(args.save_path, net, net_name + '.best')
            best_acc = valid_acc
        scheduler.step()


if __name__ == '__main__':
    main()
@@ -0,0 +1,43 @@
import argparse


def get_parse():
    parser = argparse.ArgumentParser(description='MetaD2A vs NSGANETv2')
    parser.add_argument('--save-path', type=str, default='../results', help='the path of the save directory')
    parser.add_argument('--data-path', type=str, default='../data', help='the path of the data directory')
    parser.add_argument('--data-name', type=str, default=None, help='meta-test dataset name')
    parser.add_argument('--num-gen-arch', type=int, default=200,
                        help='the number of candidate architectures generated by the generator')
    parser.add_argument('--bound', type=int, default=None)

    # original setting
    parser.add_argument('--seed', type=int, default=-1, help='random seed')
    parser.add_argument('--batch-size', type=int, default=96, help='batch size')
    parser.add_argument('--num_workers', type=int, default=2, help='number of workers for data loading')
    parser.add_argument('--gpu', type=str, default='0', help='set visible gpus')
    parser.add_argument('--lr', type=float, default=0.01, help='init learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--weight_decay', type=float, default=4e-5, help='weight decay')
    parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
    parser.add_argument('--epochs', type=int, default=150, help='num of training epochs')
    parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
    parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
    parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
    parser.add_argument('--autoaugment', action='store_true', default=True, help='use auto augmentation')

    parser.add_argument('--topk', type=int, default=10, help='top k checkpoints to save')
    parser.add_argument('--evaluate', action='store_true', default=False, help='evaluate a pretrained model')
    # model related
    parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                        help='Name of model to train (default: "countception")')
    parser.add_argument('--model-config', type=str, default='search',
                        help='location of a json file of specific model declaration')
    parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                        help='Initialize model from this checkpoint (default: none)')
    parser.add_argument('--drop', type=float, default=0.2,
                        help='dropout rate')
    parser.add_argument('--drop-path', type=float, default=0.2, metavar='PCT',
                        help='drop path rate (default: 0.2)')
    parser.add_argument('--img-size', type=int, default=224,
                        help='input resolution (192 -> 256)')
    args = parser.parse_args()
    return args
@@ -0,0 +1,261 @@
import os
import sys
import json
import logging
import numpy as np
import copy
import torch
import torch.nn as nn
import random
import torch.optim as optim

from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.evaluator import OFAEvaluator
from torchprofile import profile_macs
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks import NSGANetV2
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.parser import get_parse
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.eval_utils import get_dataset
from transfer_nag_lib.MetaD2A_nas_bench_201.metad2a_utils import reset_seed
from transfer_nag_lib.ofa_net import OFASubNet


# os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
# device_list = [int(_) for _ in args.gpu.split(',')]
# args.n_gpus = len(device_list)
# args.device = torch.device("cuda:0")

# if args.seed is None or args.seed < 0: args.seed = random.randint(1, 100000)
# torch.cuda.manual_seed(args.seed)
# torch.manual_seed(args.seed)
# np.random.seed(args.seed)
# random.seed(args.seed)


# args.save_path = os.path.join(args.save_path, f'evaluation/{args.data_name}')
# if args.model_config.startswith('flops@'):
#     args.save_path += f'-nsganetV2-{args.model_config}-{args.seed}'
# else:
#     args.save_path += f'-metaD2A-{args.bound}-{args.seed}'
# if not os.path.exists(args.save_path):
#     os.makedirs(args.save_path)

# args.data_path = os.path.join(args.data_path, args.data_name)

# log_format = '%(asctime)s %(message)s'
# logging.basicConfig(stream=sys.stdout, level=logging.INFO,
#                     format=log_format, datefmt='%m/%d %I:%M:%S %p')
# fh = logging.FileHandler(os.path.join(args.save_path, 'log.txt'))
# fh.setFormatter(logging.Formatter(log_format))
# logging.getLogger().addHandler(fh)
# if not torch.cuda.is_available():
#     print('no gpu device available')
#     sys.exit(1)
# print("args = %s", args)


def set_architecture(n_cls, evaluator, drop_path, drop, img_size, n_gpus, device, save_path, model_str):
    # g, acc = evaluator.get_architecture(model_str)
    g = OFASubNet(model_str).get_op_dict()
    subnet, config = evaluator.sample(g)
    net = NSGANetV2.build_from_config(subnet.config, drop_connect_rate=drop_path)
    net.load_state_dict(subnet.state_dict())

    NSGANetV2.reset_classifier(
        net, last_channel=net.classifier.in_features,
        n_classes=n_cls, dropout_rate=drop)
    # calculate #Parameters and #FLOPs
    inputs = torch.randn(1, 3, img_size, img_size)
    flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
    params = sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6
    net_name = "net_flops@{:.0f}".format(flops)
    print('#params {:.2f}M, #flops {:.0f}M'.format(params, flops))
    # OFAEvaluator.save_net_config(save_path, net, net_name + '.config')
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net = net.to(device)

    return net, net_name, params, flops


def train(train_queue, net, criterion, optimizer, grad_clip, device, report_freq):
    net.train()
    train_loss, correct, total = 0, 0, 0
    for step, (inputs, targets) in enumerate(train_queue):
        # upsample by bicubic to match imagenet training size
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if step % report_freq == 0:
            print(f'train step {step:03d} loss {train_loss / total:.4f} train acc {100. * correct / total:.4f}')
    print(f'train acc {100. * correct / total:.4f}')
    return train_loss / total, 100. * correct / total


def infer(valid_queue, net, criterion, device, report_freq, early_stop=False):
    net.eval()
    test_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for step, (inputs, targets) in enumerate(valid_queue):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if step % report_freq == 0:
                print(f'valid {step:03d} {test_loss / total:.4f} {100. * correct / total:.4f}')
            if early_stop and step == 10:
                break
    acc = 100. * correct / total
    print('valid acc {:.4f}'.format(100. * correct / total))

    return test_loss / total, acc


def train_single_model(save_path, workers, datasets, xpaths, splits, use_less,
                       seed, model_str, device,
                       lr=0.01,
                       momentum=0.9,
                       weight_decay=4e-5,
                       report_freq=50,
                       epochs=150,
                       grad_clip=5,
                       cutout=True,
                       cutout_length=16,
                       autoaugment=True,
                       drop=0.2,
                       drop_path=0.2,
                       img_size=224,
                       batch_size=96,
                       ):
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    reset_seed(seed)
    # save_dir = Path(save_dir)
    # logger = Logger(str(save_dir), 0, False)
    os.makedirs(save_path, exist_ok=True)
    to_save_name = save_path + '/seed-{:04d}.pth'.format(seed)
    print(to_save_name)
    # args = get_parse()
    num_gen_arch = None
    evaluator = OFAEvaluator(num_gen_arch, img_size, drop_path,
                             model_path='/home/data/GTAD/checkpoints/ofa/ofa_net/ofa_mbv3_d234_e346_k357_w1.0')

    train_queue, valid_queue, n_cls = get_dataset(datasets, batch_size,
                                                  xpaths, workers, img_size, autoaugment, cutout, cutout_length)
    net, net_name, params, flops = set_architecture(n_cls, evaluator,
                                                    drop_path, drop, img_size, n_gpus=1, device=device,
                                                    save_path=save_path, model_str=model_str)

    # net.to(device)

    parameters = filter(lambda p: p.requires_grad, net.parameters())
    optimizer = optim.SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss().to(device)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    # assert epochs == 1
    max_valid_acc = 0
    max_epoch = 0
    for epoch in range(epochs):
        print('epoch {:d} lr {:.4f}'.format(epoch, scheduler.get_lr()[0]))

        train(train_queue, net, criterion, optimizer, grad_clip, device, report_freq)
        _, valid_acc = infer(valid_queue, net, criterion, device, report_freq)
        torch.save(valid_acc, to_save_name)
        print(f'seed {seed:04d} last acc {valid_acc:.4f} max acc {max_valid_acc:.4f}')
        if max_valid_acc < valid_acc:
            max_valid_acc = valid_acc
            max_epoch = epoch
        scheduler.step()  # advance the cosine schedule once per epoch
        # parent_path = os.path.abspath(os.path.join(save_path, os.pardir))
        # with open(parent_path + '/accuracy.txt', 'a+') as f:
        #     f.write(f'{model_str} seed {seed:04d} {valid_acc:.4f}\n')

    return valid_acc, max_valid_acc, params, flops
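
A hedged usage sketch of `train_single_model` (all paths and the architecture string below are illustrative placeholders, not values from the repository):

    # Hypothetical call; `model_str` must be an OFASubNet-style architecture string.
    device = torch.device('cuda:0')
    last_acc, best_acc, params, flops = train_single_model(
        save_path='./exp/demo', workers=2,
        datasets='cifar10', xpaths='./data/cifar10',
        splits=None, use_less=False,
        seed=1, model_str='<ofa-subnet-string>', device=device,
        epochs=1, batch_size=96)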


################ NAS BENCH 201 #####################
# def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less,
#                        seeds, model_str, arch_config):
#     assert torch.cuda.is_available(), 'CUDA is not available.'
#     torch.backends.cudnn.enabled = True
#     torch.backends.cudnn.deterministic = True
#     torch.set_num_threads(workers)
#
#     save_dir = Path(save_dir)
#     logger = Logger(str(save_dir), 0, False)
#
#     if model_str in CellArchitectures:
#         arch = CellArchitectures[model_str]
#         logger.log('The model string is found in pre-defined architecture dict : {:}'.format(model_str))
#     else:
#         try:
#             arch = CellStructure.str2structure(model_str)
#         except:
#             raise ValueError('Invalid model string : {:}. It can not be found or parsed.'.format(model_str))
#
#     assert arch.check_valid_op(get_search_spaces(
#         'cell', 'nas-bench-201')), '{:} has the invalid op.'.format(arch)
#     # assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch)
#     logger.log('Start train-evaluate {:}'.format(arch.tostr()))
#     logger.log('arch_config : {:}'.format(arch_config))
#
#     start_time, seed_time = time.time(), AverageMeter()
#     for _is, seed in enumerate(seeds):
#         logger.log('\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(
#             _is, len(seeds), seed))
#         to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
#         if to_save_name.exists():
#             logger.log('Find the existing file {:}, directly load!'.format(to_save_name))
#             checkpoint = torch.load(to_save_name)
#         else:
#             logger.log('Does not find the existing file {:}, train and evaluate!'.format(to_save_name))
#             checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less,
#                                                seed, arch_config, workers, logger)
#             torch.save(checkpoint, to_save_name)
#         # log information
#         logger.log('{:}'.format(checkpoint['info']))
#         all_dataset_keys = checkpoint['all_dataset_keys']
#         for dataset_key in all_dataset_keys:
#             logger.log('\n{:} dataset : {:} {:}'.format('-' * 15, dataset_key, '-' * 15))
#             dataset_info = checkpoint[dataset_key]
#             # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
#             logger.log('Flops = {:} MB, Params = {:} MB'.format(
#                 dataset_info['flop'], dataset_info['param']))
#             logger.log('config : {:}'.format(dataset_info['config']))
#             logger.log('Training State (finish) = {:}'.format(dataset_info['finish-train']))
#             last_epoch = dataset_info['total_epoch'] - 1
#             train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
#             valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
#         # measure elapsed time
#         seed_time.update(time.time() - start_time)
#         start_time = time.time()
#         need_time = 'Time Left: {:}'.format(convert_secs2time(
#             seed_time.avg * (len(seeds) - _is - 1), True))
#         logger.log('\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}'.format(
#             _is, len(seeds), seed, need_time))
#     logger.close()
# ###################


if __name__ == '__main__':
    train_single_model()  # NOTE: requires the arguments listed in the signature above
@@ -0,0 +1,5 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from .generator import Generator
@@ -0,0 +1,204 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import random
from tqdm import tqdm
import numpy as np
import time

import torch
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from utils import load_graph_config, decode_ofa_mbv3_to_igraph, decode_igraph_to_ofa_mbv3
from utils import Accumulator, Log
from utils import load_model, save_model
from loader import get_meta_train_loader, get_meta_test_loader

from .generator_model import GeneratorModel


class Generator:
    def __init__(self, args):
        self.args = args
        self.batch_size = args.batch_size
        self.data_path = args.data_path
        self.num_sample = args.num_sample
        self.max_epoch = args.max_epoch
        self.save_epoch = args.save_epoch
        self.model_path = args.model_path
        self.save_path = args.save_path
        self.model_name = args.model_name
        self.test = args.test
        self.device = args.device

        graph_config = load_graph_config(
            args.graph_data_name, args.nvt, args.data_path)
        self.model = GeneratorModel(args, graph_config)
        self.model.to(self.device)

        if self.test:
            self.data_name = args.data_name
            self.num_class = args.num_class
            self.load_epoch = args.load_epoch
            self.num_gen_arch = args.num_gen_arch
            load_model(self.model, self.model_path, self.load_epoch)
        else:
            self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)
            self.scheduler = ReduceLROnPlateau(self.optimizer, 'min',
                                               factor=0.1, patience=10, verbose=True)
            self.mtrloader = get_meta_train_loader(
                self.batch_size, self.data_path, self.num_sample)
            self.mtrlog = Log(self.args, open(os.path.join(
                self.save_path, self.model_name, 'meta_train_generator.log'), 'w'))
            self.mtrlog.print_args()
            self.mtrlogger = Accumulator('loss', 'recon_loss', 'kld')
            self.mvallogger = Accumulator('loss', 'recon_loss', 'kld')

    def meta_train(self):
        sttime = time.time()
        for epoch in range(1, self.max_epoch + 1):
            self.mtrlog.ep_sttime = time.time()
            loss = self.meta_train_epoch(epoch)
            self.scheduler.step(loss)
            self.mtrlog.print(self.mtrlogger, epoch, tag='train')

            self.meta_validation()
            self.mtrlog.print(self.mvallogger, epoch, tag='valid')

            if epoch % self.save_epoch == 0:
                save_model(epoch, self.model, self.model_path)

        self.mtrlog.save_time_log()

    def meta_train_epoch(self, epoch):
        self.model.to(self.device)
        self.model.train()

        self.mtrloader.dataset.set_mode('train')
        pbar = tqdm(self.mtrloader)

        for batch in pbar:
            for x, g, acc in batch:
                self.optimizer.zero_grad()
                g = decode_ofa_mbv3_to_igraph(g)[0]
                x_ = x.unsqueeze(0).to(self.device)
                mu, logvar = self.model.set_encode(x_)
                loss, recon, kld = self.model.loss(mu.unsqueeze(0), logvar.unsqueeze(0), [g])
                loss.backward()
                self.optimizer.step()
                cnt = len(x)
                self.mtrlogger.accum([loss.item() / cnt,
                                      recon.item() / cnt,
                                      kld.item() / cnt])

        return self.mtrlogger.get('loss')

    def meta_validation(self):
        self.model.to(self.device)
        self.model.eval()

        self.mtrloader.dataset.set_mode('valid')
        pbar = tqdm(self.mtrloader)

        for batch in pbar:
            for x, g, acc in batch:
                with torch.no_grad():
                    g = decode_ofa_mbv3_to_igraph(g)[0]
                    x_ = x.unsqueeze(0).to(self.device)
                    mu, logvar = self.model.set_encode(x_)
                    loss, recon, kld = self.model.loss(mu.unsqueeze(0), logvar.unsqueeze(0), [g])

                cnt = len(x)
                self.mvallogger.accum([loss.item() / cnt,
                                       recon.item() / cnt,
                                       kld.item() / cnt])

        return self.mvallogger.get('loss')

    def meta_test(self, predictor):
        if self.data_name == 'all':
            for data_name in ['cifar100', 'cifar10', 'mnist', 'svhn', 'aircraft30', 'aircraft100', 'pets']:
                self.meta_test_per_dataset(data_name, predictor)
        else:
            self.meta_test_per_dataset(self.data_name, predictor)

    def meta_test_per_dataset(self, data_name, predictor):
        # meta_test_path = os.path.join(
        #     self.save_path, 'meta_test', data_name, 'generated_arch')
        meta_test_path = os.path.join(
            self.save_path, 'meta_test', data_name, f'{self.num_gen_arch}', 'generated_arch')
        if not os.path.exists(meta_test_path):
            os.makedirs(meta_test_path)

        meta_test_loader = get_meta_test_loader(
            self.data_path, data_name, self.num_sample, self.num_class)

        print(f'==> generate architectures for {data_name}')
        runs = 10 if data_name in ['cifar10', 'cifar100'] else 1
        # num_gen_arch = 500 if data_name in ['cifar100'] else self.num_gen_arch
        elapsed_time = []
        for run in range(1, runs + 1):
            print(f'==> run {run}/{runs}')
            elapsed_time.append(self.generate_architectures(
                meta_test_loader, data_name,
                meta_test_path, run, self.num_gen_arch, predictor))
            print(f'==> done\n')

        # time_path = os.path.join(self.save_path, 'meta_test', data_name, 'time.txt')
        time_path = os.path.join(self.save_path, 'meta_test', data_name, f'{self.num_gen_arch}', 'time.txt')
        with open(time_path, 'w') as f_time:
            msg = f'generator elapsed time {np.mean(elapsed_time):.2f}s'
            print(f'==> save time in {time_path}')
            f_time.write(msg + '\n')
            print(msg)

    def generate_architectures(self, meta_test_loader, data_name,
                               meta_test_path, run, num_gen_arch, predictor):
        self.model.eval()
        self.model.to(self.device)

        architecture_string_lst, pred_acc_lst = [], []
        total_cnt, valid_cnt = 0, 0
        flag = False

        start = time.time()
        with torch.no_grad():
            for x in meta_test_loader:
                x_ = x.unsqueeze(0).to(self.device)
                mu, logvar = self.model.set_encode(x_)
                z = self.model.reparameterize(mu.unsqueeze(0), logvar.unsqueeze(0))
                g_recon = self.model.graph_decode(z)
                pred_acc = predictor.forward(x_, g_recon)
                architecture_string = decode_igraph_to_ofa_mbv3(g_recon[0])
                total_cnt += 1
                if architecture_string is not None:
                    if architecture_string not in architecture_string_lst:
                        valid_cnt += 1
                        architecture_string_lst.append(architecture_string)
                        pred_acc_lst.append(pred_acc.item())
                        if valid_cnt == num_gen_arch:
                            flag = True
                            break
                if flag:
                    break
        elapsed = time.time() - start
        pred_acc_lst, architecture_string_lst = zip(*sorted(zip(pred_acc_lst,
                                                                architecture_string_lst),
                                                            key=lambda x: x[0], reverse=True))

        spath = os.path.join(meta_test_path, f"run_{run}.txt")
        with open(spath, 'w') as f:
            print(f'==> save generated architectures in {spath}')
            msg = f'elapsed time: {elapsed:6.2f}s '
            print(msg)
            f.write(msg + '\n')
            for i, architecture_string in enumerate(architecture_string_lst):
                f.write(f"{architecture_string}\n")
        return elapsed
@@ -0,0 +1,396 @@
######################################################################################
# Copyright (c) muhanzhang, D-VAE, NeurIPS 2019 [GitHub D-VAE]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import igraph
from set_encoder.setenc_models import SetPool


class GeneratorModel(nn.Module):
    def __init__(self, args, graph_config):
        super(GeneratorModel, self).__init__()
        self.max_n = graph_config['max_n']  # maximum number of vertices
        self.nvt = graph_config['num_vertex_type']  # number of vertex types
        self.START_TYPE = graph_config['START_TYPE']
        self.END_TYPE = graph_config['END_TYPE']
        self.hs = args.hs  # hidden state size of each vertex
        self.nz = args.nz  # size of latent representation z
        self.gs = args.hs  # size of graph state
        self.bidir = True  # whether to use bidirectional encoding
        self.vid = True
        self.device = None
        self.num_sample = args.num_sample

        if self.vid:
            self.vs = self.hs + self.max_n  # vertex state size = hidden state + vid
        else:
            self.vs = self.hs

        # 0. encoding-related
        self.grue_forward = nn.GRUCell(self.nvt, self.hs)  # encoder GRU
        self.grue_backward = nn.GRUCell(self.nvt, self.hs)  # backward encoder GRU
        self.enc_g_mu = nn.Linear(self.gs, self.nz)  # latent mean
        self.enc_g_var = nn.Linear(self.gs, self.nz)  # latent var
        self.fc1 = nn.Linear(self.gs, self.nz)  # latent mean
        self.fc2 = nn.Linear(self.gs, self.nz)  # latent logvar

        # 1. decoding-related
        self.grud = nn.GRUCell(self.nvt, self.hs)  # decoder GRU
        self.fc3 = nn.Linear(self.nz, self.hs)  # from latent z to initial hidden state h0
        self.add_vertex = nn.Sequential(
            nn.Linear(self.hs, self.hs * 2),
            nn.ReLU(),
            nn.Linear(self.hs * 2, self.nvt)
        )  # which type of new vertex to add f(h0, hg)
        self.add_edge = nn.Sequential(
            nn.Linear(self.hs * 2, self.hs * 4),
            nn.ReLU(),
            nn.Linear(self.hs * 4, 1)
        )  # whether to add edge between v_i and v_new, f(hvi, hnew)
        self.decoding_gate = nn.Sequential(
            nn.Linear(self.vs, self.hs),
            nn.Sigmoid()
        )
        self.decoding_mapper = nn.Sequential(
            nn.Linear(self.vs, self.hs, bias=False),
        )  # disable bias to ensure padded zeros also mapped to zeros

        # 2. gate-related
        self.gate_forward = nn.Sequential(
            nn.Linear(self.vs, self.hs),
            nn.Sigmoid()
        )
        self.gate_backward = nn.Sequential(
            nn.Linear(self.vs, self.hs),
            nn.Sigmoid()
        )
        self.mapper_forward = nn.Sequential(
            nn.Linear(self.vs, self.hs, bias=False),
        )  # disable bias to ensure padded zeros also mapped to zeros
        self.mapper_backward = nn.Sequential(
            nn.Linear(self.vs, self.hs, bias=False),
        )

        # 3. bidir-related, to unify sizes
        if self.bidir:
            self.hv_unify = nn.Sequential(
                nn.Linear(self.hs * 2, self.hs),
            )
            self.hg_unify = nn.Sequential(
                nn.Linear(self.gs * 2, self.gs),
            )

        # 4. other
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.logsoftmax1 = nn.LogSoftmax(1)

        # 6. predictor
        self.intra_setpool = SetPool(dim_input=512,
                                     num_outputs=1,
                                     dim_output=self.nz,
                                     dim_hidden=self.nz,
                                     mode='sabPF')
        self.inter_setpool = SetPool(dim_input=self.nz,
                                     num_outputs=1,
                                     dim_output=self.nz,
                                     dim_hidden=self.nz,
                                     mode='sabPF')
        self.set_fc = nn.Sequential(
            nn.Linear(512, self.nz),
            nn.ReLU())

    def get_device(self):
        if self.device is None:
            self.device = next(self.parameters()).device
        return self.device

    def _get_zeros(self, n, length):
        return torch.zeros(n, length).to(self.get_device())  # get a zero hidden state

    def _get_zero_hidden(self, n=1):
        return self._get_zeros(n, self.hs)  # get a zero hidden state

    def _one_hot(self, idx, length):
        if type(idx) in [list, range]:
            if idx == []:
                return None
            idx = torch.LongTensor(idx).unsqueeze(0).t()
            x = torch.zeros((len(idx), length)).scatter_(1, idx, 1).to(self.get_device())
        else:
            idx = torch.LongTensor([idx]).unsqueeze(0)
            x = torch.zeros((1, length)).scatter_(1, idx, 1).to(self.get_device())
        return x

    def _gated(self, h, gate, mapper):
        return gate(h) * mapper(h)

    def _collate_fn(self, G):
        return [g.copy() for g in G]

    def _propagate_to(self, G, v, propagator,
                      H=None, reverse=False, gate=None, mapper=None):
        # propagate messages to vertex index v for all graphs in G
        # return the new messages (states) at v
        G = [g for g in G if g.vcount() > v]
        if len(G) == 0:
            return
        if H is not None:
            idx = [i for i, g in enumerate(G) if g.vcount() > v]
            H = H[idx]
        v_types = [g.vs[v]['type'] for g in G]
        X = self._one_hot(v_types, self.nvt)
        if reverse:
            H_name = 'H_backward'  # name of the hidden states attribute
            H_pred = [[g.vs[x][H_name] for x in g.successors(v)] for g in G]
            if self.vid:
                vids = [self._one_hot(g.successors(v), self.max_n) for g in G]
            gate, mapper = self.gate_backward, self.mapper_backward
        else:
            H_name = 'H_forward'  # name of the hidden states attribute
            H_pred = [
                [g.vs[x][H_name] for x in g.predecessors(v)] for g in G]
            if self.vid:
                vids = [
                    self._one_hot(g.predecessors(v), self.max_n) for g in G]
            if gate is None:
                gate, mapper = self.gate_forward, self.mapper_forward
        if self.vid:
            H_pred = [[torch.cat(
                [x[i], y[i:i + 1]], 1) for i in range(len(x))
            ] for x, y in zip(H_pred, vids)]
        # if H is not provided, use gated sum of v's predecessors' states as the input hidden state
        if H is None:
            max_n_pred = max([len(x) for x in H_pred])  # maximum number of predecessors
            if max_n_pred == 0:
                H = self._get_zero_hidden(len(G))
            else:
                H_pred = [torch.cat(h_pred +
                                    [self._get_zeros(max_n_pred - len(h_pred),
                                                     self.vs)], 0).unsqueeze(0)
                          for h_pred in H_pred]  # pad all to same length
                H_pred = torch.cat(H_pred, 0)  # batch * max_n_pred * vs
                H = self._gated(H_pred, gate, mapper).sum(1)  # batch * hs
        Hv = propagator(X, H)
        for i, g in enumerate(G):
            g.vs[v][H_name] = Hv[i:i + 1]
        return Hv

    def _propagate_from(self, G, v, propagator, H0=None, reverse=False):
        # perform a series of propagation_to steps starting from v following a topo order
        # assume the original vertex indices are in a topological order
        if reverse:
            prop_order = range(v, -1, -1)
        else:
            prop_order = range(v, self.max_n)
        Hv = self._propagate_to(G, v, propagator, H0, reverse=reverse)  # the initial vertex
        for v_ in prop_order[1:]:
            self._propagate_to(G, v_, propagator, reverse=reverse)
        return Hv

    def _update_v(self, G, v, H0=None):
        # perform a forward propagation step at v when decoding to update v's state
        # self._propagate_to(G, v, self.grud, H0, reverse=False)
        self._propagate_to(G, v, self.grud, H0,
                           reverse=False, gate=self.decoding_gate,
                           mapper=self.decoding_mapper)
        return

    def _get_vertex_state(self, G, v):
        # get the vertex states at v
        Hv = []
        for g in G:
            if v >= g.vcount():
                hv = self._get_zero_hidden()
            else:
                hv = g.vs[v]['H_forward']
            Hv.append(hv)
        Hv = torch.cat(Hv, 0)
        return Hv

    def _get_graph_state(self, G, decode=False):
        # get the graph states
        # when decoding, use the last generated vertex's state as the graph state
        # when encoding, use the ending vertex state or unify the starting and ending vertex states
        Hg = []
        for g in G:
            hg = g.vs[g.vcount() - 1]['H_forward']
            if self.bidir and not decode:  # decoding never uses backward propagation
                hg_b = g.vs[0]['H_backward']
                hg = torch.cat([hg, hg_b], 1)
            Hg.append(hg)
        Hg = torch.cat(Hg, 0)
        if self.bidir and not decode:
            Hg = self.hg_unify(Hg)
        return Hg

    def graph_encode(self, G):
        # encode graphs G into latent vectors
        if type(G) != list:
            G = [G]
        self._propagate_from(G, 0, self.grue_forward,
                             H0=self._get_zero_hidden(len(G)), reverse=False)
        if self.bidir:
            self._propagate_from(G, self.max_n - 1, self.grue_backward,
                                 H0=self._get_zero_hidden(len(G)), reverse=True)
        Hg = self._get_graph_state(G)
        mu, logvar = self.enc_g_mu(Hg), self.enc_g_var(Hg)
        return mu, logvar

    def set_encode(self, X):
        proto_batch = []
        for x in X:  # X.shape: [32, 400, 512]
            cls_protos = self.intra_setpool(
                x.view(-1, self.num_sample, 512)).squeeze(1)
            proto_batch.append(
                self.inter_setpool(cls_protos.unsqueeze(0)))
        v = torch.stack(proto_batch).squeeze()
        mu, logvar = self.fc1(v), self.fc2(v)
        return mu, logvar
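
As a reading aid, a sketch of the tensor shapes flowing through `set_encode`, under the assumption (matching the inline comment above) that a batch holds 32 tasks of 20 classes with num_sample = 20 and 512-d features:

    # Hypothetical shapes only:
    # X:                                      [32, 400, 512]  (400 = 20 classes * 20 samples)
    # x.view(-1, num_sample, 512):            [20, 20, 512]   one sample-set per class
    # intra_setpool(...).squeeze(1):          [20, nz]        one prototype per class
    # inter_setpool(cls_protos.unsqueeze(0)): [1, 1, nz]      one embedding per task
    # v = torch.stack(...).squeeze():         [32, nz]
    # mu, logvar = fc1(v), fc2(v):            [32, nz] each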

    def reparameterize(self, mu, logvar, eps_scale=0.01):
        # return z ~ N(mu, std)
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.randn_like(std) * eps_scale
            return eps.mul(std).add_(mu)
        else:
            return mu
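
In formula form, the branch above implements the standard VAE reparameterization trick, with an extra damping factor s = eps_scale on the noise:

    z = \mu + \varepsilon \odot \sigma, \qquad
    \sigma = \exp\!\left(\tfrac{1}{2}\log\sigma^{2}\right), \qquad
    \varepsilon \sim \mathcal{N}(0,\, s^{2} I)

with s = 0.01 by default during training, while evaluation simply returns the mean \mu.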

    def _get_edge_score(self, Hvi, H, H0):
        # compute scores for edges from vi based on Hvi, H (current vertex) and H0
        # in most cases, H0 need not be explicitly included since Hvi and H contain its information
        return self.sigmoid(self.add_edge(torch.cat([Hvi, H], -1)))

    def graph_decode(self, z, stochastic=True):
        # decode latent vectors z back to graphs
        # if stochastic=True, stochastically sample each action from the predicted distribution;
        # otherwise, select argmax action deterministically.
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
        self._update_v(G, 0, H0)
        finished = [False] * len(G)
        for idx in range(1, self.max_n):
            # decide the type of the next added vertex
            if idx == self.max_n - 1:  # force the last node to be end_type
                new_types = [self.END_TYPE] * len(G)
            else:
                Hg = self._get_graph_state(G, decode=True)
                type_scores = self.add_vertex(Hg)
                if stochastic:
                    type_probs = F.softmax(type_scores, 1).cpu().detach().numpy()
                    new_types = [np.random.choice(range(self.nvt),
                                                  p=type_probs[i]) for i in range(len(G))]
                else:
                    new_types = torch.argmax(type_scores, 1)
                    new_types = new_types.flatten().tolist()
            for i, g in enumerate(G):
                if not finished[i]:
                    g.add_vertex(type=new_types[i])
            self._update_v(G, idx)

            # decide connections
            for vi in range(idx - 1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, idx)
                ei_score = self._get_edge_score(Hvi, H, H0)
                if stochastic:
                    random_score = torch.rand_like(ei_score)
                    decisions = random_score < ei_score
                else:
                    decisions = ei_score > 0.5
                for i, g in enumerate(G):
                    if finished[i]:
                        continue
                    if new_types[i] == self.END_TYPE:
                        # if new node is end_type, connect it to all loose-end vertices (out_degree==0)
                        end_vertices = set([
                            v.index for v in g.vs.select(_outdegree_eq=0)
                            if v.index != g.vcount() - 1])
                        for v in end_vertices:
                            g.add_edge(v, g.vcount() - 1)
                        finished[i] = True
                        continue
                    if decisions[i, 0]:
                        g.add_edge(vi, g.vcount() - 1)
                self._update_v(G, idx)

        for g in G:
            del g.vs['H_forward']  # delete hidden states to save GPU memory
        return G
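
A minimal sketch of driving the decoder end to end (hypothetical usage, assuming a constructed `model`; `decode_igraph_to_ofa_mbv3` is the converter imported in generator.py):

    # Hypothetical usage sketch.
    model.eval()
    z = torch.randn(4, model.nz)    # latent samples from the prior
    graphs = model.graph_decode(z)  # list of igraph.Graph DAGs
    archs = [decode_igraph_to_ofa_mbv3(g) for g in graphs]  # entries may be None for invalid graphs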

    def loss(self, mu, logvar, G_true, beta=0.005):
        # compute the loss of decoding mu and logvar to true graphs using teacher forcing
        # ensure when computing the loss of step i, steps 0 to i-1 are correct
        z = self.reparameterize(mu, logvar)
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
        self._update_v(G, 0, H0)
        res = 0  # log likelihood
        for v_true in range(1, self.max_n):
            # calculate the likelihood of adding true types of nodes
            # use start type to denote padding vertices since start type only appears for vertex 0
            # and will never be a true type for later vertices, thus it's free to use
            true_types = [g_true.vs[v_true]['type']
                          if v_true < g_true.vcount()
                          else self.START_TYPE for g_true in G_true]
            Hg = self._get_graph_state(G, decode=True)
            type_scores = self.add_vertex(Hg)
            # vertex log likelihood
            vll = self.logsoftmax1(type_scores)[
                np.arange(len(G)), true_types].sum()
            res = res + vll
            for i, g in enumerate(G):
                if true_types[i] != self.START_TYPE:
                    g.add_vertex(type=true_types[i])
            self._update_v(G, v_true)

            # calculate the likelihood of adding true edges
            true_edges = []
            for i, g_true in enumerate(G_true):
                true_edges.append(g_true.get_adjlist(igraph.IN)[v_true]
                                  if v_true < g_true.vcount() else [])
            edge_scores = []
            for vi in range(v_true - 1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, v_true)
                ei_score = self._get_edge_score(Hvi, H, H0)
                edge_scores.append(ei_score)
                for i, g in enumerate(G):
                    if vi in true_edges[i]:
                        g.add_edge(vi, v_true)
                self._update_v(G, v_true)
            edge_scores = torch.cat(edge_scores[::-1], 1)

            ground_truth = torch.zeros_like(edge_scores)
            idx1 = [i for i, x in enumerate(true_edges)
                    for _ in range(len(x))]
            idx2 = [xx for x in true_edges for xx in x]
            ground_truth[idx1, idx2] = 1.0

            # edges log-likelihood
            ell = - F.binary_cross_entropy(
                edge_scores, ground_truth, reduction='sum')
            res = res + ell

        res = -res  # convert likelihood to loss
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return res + beta * kld, res, kld
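
The return value is the usual beta-weighted VAE objective; the `kld` line above is the closed-form KL divergence between the diagonal Gaussian posterior and the standard normal prior:

    \mathcal{L} = -\log p(G \mid z) + \beta\, D_{\mathrm{KL}}, \qquad
    D_{\mathrm{KL}} = -\tfrac{1}{2}\sum_{j}\left(1 + \log\sigma_{j}^{2} - \mu_{j}^{2} - \sigma_{j}^{2}\right)

with beta = 0.005 by default; the reconstruction term sums the vertex-type log-likelihood and the edge binary cross-entropy accumulated in the loop above.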
@@ -0,0 +1,37 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile


def download_file(url, filename):
    """
    Helper method handling downloading large files from `url`
    to `filename`. Returns a pointer to `filename`.
    """
    chunkSize = 1024
    r = requests.get(url, stream=True)
    with open(filename, 'wb') as f:
        pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
        for chunk in r.iter_content(chunk_size=chunkSize):
            if chunk:  # filter out keep-alive new chunks
                pbar.update(len(chunk))
                f.write(chunk)
    return filename


file_name = 'ckpt_120.pt'
dir_path = 'results/generator/model'
if not os.path.exists(dir_path):
    os.makedirs(dir_path)
file_name = os.path.join(dir_path, file_name)
if not os.path.exists(file_name):
    print(f"Downloading {file_name}\n")
    download_file('https://www.dropbox.com/s/zss9yt034hen45h/ckpt_120.pt?dl=1', file_name)
    print("Downloading done.\n")
else:
    print(f"{file_name} has already been downloaded. Did not download twice.\n")
@@ -0,0 +1,38 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile


def download_file(url, filename):
    """
    Helper method handling downloading large files from `url`
    to `filename`. Returns a pointer to `filename`.
    """
    chunkSize = 1024
    r = requests.get(url, stream=True)
    with open(filename, 'wb') as f:
        pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
        for chunk in r.iter_content(chunk_size=chunkSize):
            if chunk:  # filter out keep-alive new chunks
                pbar.update(len(chunk))
                f.write(chunk)
    return filename


file_name = 'collected_database.pt'
dir_path = 'data/generator/processed'
if not os.path.exists(dir_path):
    os.makedirs(dir_path)
file_name = os.path.join(dir_path, file_name)
if not os.path.exists(file_name):
    print(f"Downloading generator {file_name}\n")
    download_file('https://www.dropbox.com/s/zgip4aq0w2pkj49/generator_collected_database.pt?dl=1', file_name)
    print("Downloading done.\n")
else:
    print(f"{file_name} has already been downloaded. Did not download twice.\n")
@@ -0,0 +1,43 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile


def download_file(url, filename):
    """
    Helper method handling downloading large files from `url`
    to `filename`. Returns a pointer to `filename`.
    """
    chunkSize = 1024
    r = requests.get(url, stream=True)
    with open(filename, 'wb') as f:
        pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
        for chunk in r.iter_content(chunk_size=chunkSize):
            if chunk:  # filter out keep-alive new chunks
                pbar.update(len(chunk))
                f.write(chunk)
    return filename


dir_path = 'data/pets'
if not os.path.exists(dir_path):
    os.makedirs(dir_path)

full_name = os.path.join(dir_path, 'test15.pth')
if not os.path.exists(full_name):
    print(f"Downloading {full_name}\n")
    download_file('https://www.dropbox.com/s/kzmrwyyk5iaugv0/test15.pth?dl=1', full_name)
    print("Downloading done.\n")
else:
    print(f"{full_name} has already been downloaded. Did not download twice.\n")

full_name = os.path.join(dir_path, 'train85.pth')
if not os.path.exists(full_name):
    print(f"Downloading {full_name}\n")
    download_file('https://www.dropbox.com/s/w7mikpztkamnw9s/train85.pth?dl=1', full_name)
    print("Downloading done.\n")
else:
    print(f"{full_name} has already been downloaded. Did not download twice.\n")
@@ -0,0 +1,35 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile


def download_file(url, filename):
    """
    Helper method handling downloading large files from `url`
    to `filename`. Returns a pointer to `filename`.
    """
    chunkSize = 1024
    r = requests.get(url, stream=True)
    with open(filename, 'wb') as f:
        pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
        for chunk in r.iter_content(chunk_size=chunkSize):
            if chunk:  # filter out keep-alive new chunks
                pbar.update(len(chunk))
                f.write(chunk)
    return filename


file_name = 'ckpt_max_corr.pt'
dir_path = 'results/predictor/model'
if not os.path.exists(dir_path):
    os.makedirs(dir_path)
file_name = os.path.join(dir_path, file_name)
if not os.path.exists(file_name):
    print(f"Downloading {file_name}\n")
    download_file('https://www.dropbox.com/s/ycm4jaojgswp0zm/ckpt_max_corr.pt?dl=1', file_name)
    print("Downloading done.\n")
else:
    print(f"{file_name} has already been downloaded. Did not download twice.\n")
@@ -0,0 +1,38 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile


def download_file(url, filename):
    """
    Helper method handling downloading large files from `url`
    to `filename`. Returns a pointer to `filename`.
    """
    chunkSize = 1024
    r = requests.get(url, stream=True)
    with open(filename, 'wb') as f:
        pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
        for chunk in r.iter_content(chunk_size=chunkSize):
            if chunk:  # filter out keep-alive new chunks
                pbar.update(len(chunk))
                f.write(chunk)
    return filename


file_name = 'collected_database.pt'
dir_path = 'data/predictor/processed'
if not os.path.exists(dir_path):
    os.makedirs(dir_path)
file_name = os.path.join(dir_path, file_name)
if not os.path.exists(file_name):
    print(f"Downloading predictor {file_name}\n")
    download_file('https://www.dropbox.com/s/ycm4jaojgswp0zm/ckpt_max_corr.pt?dl=1', file_name)
    print("Downloading done.\n")
else:
    print(f"{file_name} has already been downloaded. Did not download twice.\n")
@@ -0,0 +1,47 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile


def download_file(url, filename):
    """
    Helper method handling downloading large files from `url`
    to `filename`. Returns a pointer to `filename`.
    """
    chunkSize = 1024
    r = requests.get(url, stream=True)
    with open(filename, 'wb') as f:
        pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
        for chunk in r.iter_content(chunk_size=chunkSize):
            if chunk:  # filter out keep-alive new chunks
                pbar.update(len(chunk))
                f.write(chunk)
    return filename


dir_path = 'data'
if not os.path.exists(dir_path):
    os.makedirs(dir_path)


def get_preprocessed_data(file_name, url):
    print(f"Downloading {file_name} datasets\n")
    full_name = os.path.join(dir_path, file_name)
    download_file(url, full_name)
    print("Downloading done.\n")


for file_name, url in [
        ('imgnet32bylabel.pt', 'https://www.dropbox.com/s/7r3hpugql8qgi9d/imgnet32bylabel.pt?dl=1'),
        ('aircraft100bylabel.pt', 'https://www.dropbox.com/s/nn6mlrk1jijg108/aircraft100bylabel.pt?dl=1'),
        ('cifar100bylabel.pt', 'https://www.dropbox.com/s/y0xahxgzj29kffk/cifar100bylabel.pt?dl=1'),
        ('cifar10bylabel.pt', 'https://www.dropbox.com/s/wt1pcwi991xyhwr/cifar10bylabel.pt?dl=1'),
        ('petsbylabel.pt', 'https://www.dropbox.com/s/mxh6qz3grhy7wcn/petsbylabel.pt?dl=1'),
        ('mnistbylabel.pt', 'https://www.dropbox.com/s/86rbuic7a7y34e4/mnistbylabel.pt?dl=1'),
        ('svhnbylabel.pt', 'https://www.dropbox.com/s/yywaelhrsl6egvd/svhnbylabel.pt?dl=1')]:
    get_preprocessed_data(file_name, url)
@@ -0,0 +1,149 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import torch
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


def get_meta_train_loader(batch_size, data_path, num_sample, is_pred=False):
    dataset = MetaTrainDatabase(data_path, num_sample, is_pred)
    print(f'==> The number of tasks for meta-training: {len(dataset)}')

    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=1,
                        collate_fn=collate_fn)
    return loader


def get_meta_test_loader(data_path, data_name, num_sample, num_class=None, is_pred=False):
    dataset = MetaTestDataset(data_path, data_name, num_sample, num_class)
    print(f'==> Meta-Test dataset {data_name}')

    loader = DataLoader(dataset=dataset,
                        batch_size=100,
                        shuffle=False,
                        num_workers=1)
    return loader


class MetaTrainDatabase(Dataset):
    def __init__(self, data_path, num_sample, is_pred=False):
        self.mode = 'train'
        self.acc_norm = True
        self.num_sample = num_sample
        self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))

        self.dpath = '{}/{}/processed/'.format(data_path, 'predictor' if is_pred else 'generator')
        self.dname = 'database_219152_14.0K'

        # build the train/valid/test splits once if they are not cached yet
        if not os.path.exists(self.dpath + f'{self.dname}_train.pt'):
            database = torch.load(self.dpath + f'{self.dname}.pt')

            rand_idx = torch.randperm(len(database))
            test_len = int(len(database) * 0.15)
            idxlst = {'test': rand_idx[:test_len],
                      'valid': rand_idx[test_len:2 * test_len],
                      'train': rand_idx[2 * test_len:]}

            for m in ['train', 'valid', 'test']:
                acc, graph, cls, net, flops = [], [], [], [], []
                for idx in tqdm(idxlst[m].tolist(), desc=f'data-{m}'):
                    acc.append(database[idx]['top1'])
                    net.append(database[idx]['net'])
                    cls.append(database[idx]['class'])
                    flops.append(database[idx]['flops'])
                if m == 'train':
                    mean = torch.mean(torch.tensor(acc)).item()
                    std = torch.std(torch.tensor(acc)).item()
                torch.save({'acc': acc,
                            'class': cls,
                            'net': net,
                            'flops': flops,
                            'mean': mean,
                            'std': std},
                           self.dpath + f'{self.dname}_{m}.pt')

        self.set_mode(self.mode)

    def set_mode(self, mode):
        self.mode = mode
        data = torch.load(self.dpath + f'{self.dname}_{self.mode}.pt')
        self.acc = data['acc']
        self.cls = data['class']
        self.net = data['net']
        self.flops = data['flops']
        self.mean = data['mean']
        self.std = data['std']

    def __len__(self):
        return len(self.acc)

    def __getitem__(self, index):
        data = []
        classes = self.cls[index]
        acc = self.acc[index]
        graph = self.net[index]

        for i, cls in enumerate(classes):
            cx = self.x[cls.item()][0]
            ridx = torch.randperm(len(cx))
            data.append(cx[ridx[:self.num_sample]])
        x = torch.cat(data)
        if self.acc_norm:
            acc = ((acc - self.mean) / self.std) / 100.0
        else:
            acc = acc / 100.0
        return x, graph, torch.tensor(acc).view(1, 1)
|
||||
|
||||
|
||||
class MetaTestDataset(Dataset):
|
||||
def __init__(self, data_path, data_name, num_sample, num_class=None):
|
||||
self.num_sample = num_sample
|
||||
self.data_name = data_name
|
||||
if data_name == 'aircraft':
|
||||
data_name = 'aircraft100'
|
||||
num_class_dict = {
|
||||
'cifar100': 100,
|
||||
'cifar10': 10,
|
||||
'mnist': 10,
|
||||
'aircraft100': 30,
|
||||
'svhn': 10,
|
||||
'pets': 37
|
||||
}
|
||||
# 'aircraft30': 30,
|
||||
# 'aircraft100': 100,
|
||||
|
||||
if num_class is not None:
|
||||
self.num_class = num_class
|
||||
else:
|
||||
self.num_class = num_class_dict[data_name]
|
||||
|
||||
self.x = torch.load(os.path.join(data_path, f'{data_name}bylabel.pt'))
|
||||
|
||||
def __len__(self):
|
||||
return 1000000
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = []
|
||||
classes = list(range(self.num_class))
|
||||
for cls in classes:
|
||||
cx = self.x[cls][0]
|
||||
ridx = torch.randperm(len(cx))
|
||||
data.append(cx[ridx[:self.num_sample]])
|
||||
x = torch.cat(data)
|
||||
return x
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
# x = torch.stack([item[0] for item in batch])
|
||||
# graph = [item[1] for item in batch]
|
||||
# acc = torch.stack([item[2] for item in batch])
|
||||
return batch
|
||||
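
# Usage sketch (not part of the original commit). The data_path below is an
# assumption and must point at the preprocessed files this loader expects.
if __name__ == '__main__':
    loader = get_meta_train_loader(batch_size=32,
                                   data_path='./data/ofa/raw_data',  # assumed path
                                   num_sample=20, is_pred=True)
    for batch in loader:
        # collate_fn returns the raw list, so each element is an (x, graph, acc) triple
        x, graph, acc = batch[0]
        print(x.shape, acc.item())
        break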
@@ -0,0 +1,48 @@
|
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
import random
import numpy as np
import torch
from parser import get_parser
from generator import Generator
from predictor import Predictor


def main():
    args = get_parser()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    args.device = torch.device("cuda:0")
    torch.cuda.manual_seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    args.model_path = os.path.join(args.save_path, args.model_name, 'model')
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    if args.model_name == 'generator':
        g = Generator(args)
        if args.test:
            args.model_path = os.path.join(args.save_path, 'predictor', 'model')
            hs = args.hs
            args.hs = 512
            p = Predictor(args)
            args.model_path = os.path.join(args.save_path, args.model_name, 'model')
            args.hs = hs
            g.meta_test(p)
        else:
            g.meta_train()
    elif args.model_name == 'predictor':
        p = Predictor(args)
        p.meta_train()
    else:
        raise ValueError('You should select generator|predictor|train_arch')


if __name__ == '__main__':
    main()
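
# Example invocations (not part of the original commit; the flag names are inferred
# from the attributes used above and assume get_parser defines them):
#
#   python main.py --gpu 0 --model_name predictor --save_path ./results --seed 1
#   python main.py --gpu 0 --model_name generator --save_path ./results --seed 1 --test
#
# The --test path for the generator temporarily loads a fixed-width (hs=512) predictor
# checkpoint from <save_path>/predictor/model before running meta-test.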
@@ -0,0 +1,344 @@
|
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import time
import igraph
import random
import numpy as np
import scipy.stats
import argparse
import torch


def load_graph_config(graph_data_name, nvt, data_path):
    max_n = 20
    graph_config = {}
    graph_config['num_vertex_type'] = nvt + 2  # original types + start/end types
    graph_config['max_n'] = max_n + 2  # maximum number of nodes
    graph_config['START_TYPE'] = 0  # predefined start vertex type
    graph_config['END_TYPE'] = 1  # predefined end vertex type

    return graph_config


type_dict = {'2-3-3': 0, '2-3-4': 1, '2-3-6': 2,
             '2-5-3': 3, '2-5-4': 4, '2-5-6': 5,
             '2-7-3': 6, '2-7-4': 7, '2-7-6': 8,
             '3-3-3': 9, '3-3-4': 10, '3-3-6': 11,
             '3-5-3': 12, '3-5-4': 13, '3-5-6': 14,
             '3-7-3': 15, '3-7-4': 16, '3-7-6': 17,
             '4-3-3': 18, '4-3-4': 19, '4-3-6': 20,
             '4-5-3': 21, '4-5-4': 22, '4-5-6': 23,
             '4-7-3': 24, '4-7-4': 25, '4-7-6': 26}

edge_dict = {2: (2, 3, 3), 3: (2, 3, 4), 4: (2, 3, 6),
             5: (2, 5, 3), 6: (2, 5, 4), 7: (2, 5, 6),
             8: (2, 7, 3), 9: (2, 7, 4), 10: (2, 7, 6),
             11: (3, 3, 3), 12: (3, 3, 4), 13: (3, 3, 6),
             14: (3, 5, 3), 15: (3, 5, 4), 16: (3, 5, 6),
             17: (3, 7, 3), 18: (3, 7, 4), 19: (3, 7, 6),
             20: (4, 3, 3), 21: (4, 3, 4), 22: (4, 3, 6),
             23: (4, 5, 3), 24: (4, 5, 4), 25: (4, 5, 6),
             26: (4, 7, 3), 27: (4, 7, 4), 28: (4, 7, 6)}

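# Sanity-check sketch (not part of the original commit): edge_dict is keyed by the
# igraph vertex type, which is the type_dict index shifted by +2 for the in/out
# nodes, so the two tables should agree under that offset.
assert all(edge_dict[idx + 2] == tuple(int(v) for v in key.split('-'))
           for key, idx in type_dict.items())
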
def decode_ofa_mbv3_to_igraph(matrix):
    # 5 stages, 4 layers for each stage
    # d: 2, 3, 4
    # e: 3, 4, 6
    # k: 3, 5, 7

    # stage_depth to one hot
    num_stage = 5
    num_layer = 4

    node_types = torch.zeros(num_stage * num_layer)

    depths = []
    for i in range(num_stage):
        for j in range(num_layer):
            depths.append(matrix['d'][i])
    # the loop variable is named `depth` so it no longer shadows the depth list
    for i, (ks, e, depth) in enumerate(zip(matrix['ks'], matrix['e'], depths)):
        node_types[i] = type_dict[f'{depth}-{ks}-{e}']

    n = num_stage * num_layer
    g = igraph.Graph(directed=True)
    g.add_vertices(n + 2)  # + in/out nodes
    g.vs[0]['type'] = 0
    for i, v in enumerate(node_types):
        g.vs[i + 1]['type'] = v + 2  # in node: 0, out node: 1
        g.add_edge(i, i + 1)
    g.vs[n + 1]['type'] = 1
    g.add_edge(n, n + 1)
    return g, n + 2


def decode_ofa_mbv3_str_to_igraph(gen_str):
    # 5 stages, 4 layers for each stage
    # d: 2, 3, 4
    # e: 3, 4, 6
    # k: 3, 5, 7

    # stage_depth to one hot
    num_stage = 5
    num_layer = 4

    node_types = torch.zeros(num_stage * num_layer)

    split_str = gen_str.split('_')
    for i, s in enumerate(split_str):
        if s == '0-0-0':
            node_types[i] = random.randint(0, 26)
        else:
            node_types[i] = type_dict[s]

    n = num_stage * num_layer
    g = igraph.Graph(directed=True)
    g.add_vertices(n + 2)  # + in/out nodes
    g.vs[0]['type'] = 0
    for i, v in enumerate(node_types):
        g.vs[i + 1]['type'] = v + 2  # in node: 0, out node: 1
        g.add_edge(i, i + 1)
    g.vs[n + 1]['type'] = 1
    g.add_edge(n, n + 1)
    return g


def is_valid_ofa_mbv3(g, START_TYPE=0, END_TYPE=1):
    # first need to be a valid DAG computation graph
    res = is_valid_DAG(g, START_TYPE, END_TYPE)
    # in addition, node i must connect to node i+1 and the graph must have
    # exactly 20 layer nodes plus the in/out nodes
    res = res and len(g.vs['type']) == 22
    if not res:
        return res

    # every layer of a stage must fall into the same depth group as the stage's
    # first layer (the three depth groups span 9 op types each)
    types = g.vs['type'][1:-1]
    for i in range(5):
        stage_group = (types[i * 4] - 2) // 9
        if stage_group not in (0, 1, 2):
            raise ValueError
        for j in range(1, 4):
            res = res and (types[i * 4 + j] - 2) // 9 == stage_group
    return res


def is_valid_DAG(g, START_TYPE=0, END_TYPE=1):
    res = g.is_dag()
    n_start, n_end = 0, 0
    for v in g.vs:
        if v['type'] == START_TYPE:
            n_start += 1
        elif v['type'] == END_TYPE:
            n_end += 1
        if v.indegree() == 0 and v['type'] != START_TYPE:
            return False
        if v.outdegree() == 0 and v['type'] != END_TYPE:
            return False
    return res and n_start == 1 and n_end == 1


def decode_igraph_to_ofa_mbv3(g):
    if not is_valid_ofa_mbv3(g, START_TYPE=0, END_TYPE=1):
        return None

    graph = {'ks': [], 'e': [], 'd': [4, 4, 4, 4, 4]}
    for i, edge_type in enumerate(g.vs['type'][1:-1]):
        edge_type = int(edge_type)
        d, ks, e = edge_dict[edge_type]
        graph['ks'].append(ks)
        graph['e'].append(e)
        graph['d'][i // 4] = d
    return graph


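# Round-trip sketch (not part of the original commit): encode an OFA-MBv3 config to
# an igraph DAG and decode it back; the architecture below is an arbitrary example.
example = {'ks': [3] * 20, 'e': [4] * 20, 'd': [2, 3, 4, 2, 3]}
g_example, n_nodes = decode_ofa_mbv3_to_igraph(example)
assert n_nodes == 22 and decode_igraph_to_ofa_mbv3(g_example) == example
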
class Accumulator():
    def __init__(self, *args):
        self.args = args
        self.argdict = {}
        for i, arg in enumerate(args):
            self.argdict[arg] = i
        self.sums = [0] * len(args)
        self.cnt = 0

    def accum(self, val):
        val = [val] if type(val) is not list else val
        val = [v for v in val if v is not None]
        assert (len(val) == len(self.args))
        for i in range(len(val)):
            if torch.is_tensor(val[i]):
                val[i] = val[i].item()
            self.sums[i] += val[i]
        self.cnt += 1

    def clear(self):
        self.sums = [0] * len(self.args)
        self.cnt = 0

    def get(self, arg, avg=True):
        i = self.argdict.get(arg, -1)
        # `i is not -1` relied on CPython's small-int caching; compare by value instead
        assert i != -1
        if avg:
            return self.sums[i] / (self.cnt + 1e-8)
        else:
            return self.sums[i]

    def print_(self, header=None, time=None,
               logfile=None, do_not_print=[], as_int=[],
               avg=True):
        msg = '' if header is None else header + ': '
        if time is not None:
            msg += ('(%.3f secs), ' % time)

        args = [arg for arg in self.args if arg not in do_not_print]
        for arg in args:
            val = self.sums[self.argdict[arg]]
            if avg:
                val /= (self.cnt + 1e-8)
            if arg in as_int:
                msg += ('%s %d, ' % (arg, int(val)))
            else:
                msg += ('%s %.4f, ' % (arg, val))
        print(msg)

        if logfile is not None:
            logfile.write(msg + '\n')
            logfile.flush()

    def add_scalars(self, summary, header=None, tag_scalar=None,
                    step=None, avg=True, args=None):
        for arg in self.args:
            val = self.sums[self.argdict[arg]]
            if avg:
                val /= (self.cnt + 1e-8)
            tag = f'{header}/{arg}' if header is not None else arg
            if tag_scalar is not None:
                summary.add_scalars(main_tag=tag,
                                    tag_scalar_dict={tag_scalar: val},
                                    global_step=step)
            else:
                summary.add_scalar(tag=tag,
                                   scalar_value=val,
                                   global_step=step)


class Log:
    def __init__(self, args, logf, summary=None):
        self.args = args
        self.logf = logf
        self.summary = summary
        self.stime = time.time()
        self.ep_sttime = None

    def print(self, logger, epoch, tag=None, avg=True):
        if tag == 'train':
            ct = time.time() - self.ep_sttime
            tt = time.time() - self.stime
            msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
            print(msg)
            self.logf.write(msg + '\n')
        logger.print_(header=tag, logfile=self.logf, avg=avg)

        if self.summary is not None:
            logger.add_scalars(
                self.summary, header=tag, step=epoch, avg=avg)
        logger.clear()

    def print_args(self):
        argdict = vars(self.args)
        print(argdict)
        for k, v in argdict.items():
            self.logf.write(k + ': ' + str(v) + '\n')
        self.logf.write('\n')

    def set_time(self):
        self.stime = time.time()

    def save_time_log(self):
        ct = time.time() - self.stime
        msg = f'({ct:6.2f}s) meta-training phase done'
        print(msg)
        self.logf.write(msg + '\n')

    def print_pred_log(self, loss, corr, tag, epoch=None, max_corr_dict=None):
        if tag == 'train':
            ct = time.time() - self.ep_sttime
            tt = time.time() - self.stime
            msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
            self.logf.write(msg + '\n')
            print(msg)
            self.logf.flush()
        # msg = f'ep {epoch:3d} ep time {time.time() - ep_sttime:8.2f} '
        # msg += f'time {time.time() - sttime:6.2f} '
        if max_corr_dict is not None:
            max_corr = max_corr_dict['corr']
            max_loss = max_corr_dict['loss']
            msg = f'{tag}: loss {loss:.6f} ({max_loss:.6f}) '
            msg += f'corr {corr:.4f} ({max_corr:.4f})'
        else:
            msg = f'{tag}: loss {loss:.6f} corr {corr:.4f}'
        self.logf.write(msg + '\n')
        print(msg)
        self.logf.flush()

    def max_corr_log(self, max_corr_dict):
        corr = max_corr_dict['corr']
        loss = max_corr_dict['loss']
        epoch = max_corr_dict['epoch']
        msg = f'[epoch {epoch}] max correlation: {corr:.4f}, loss: {loss:.6f}'
        self.logf.write(msg + '\n')
        print(msg)
        self.logf.flush()


def get_log(epoch, loss, y_pred, y, acc_std, acc_mean, tag='train'):
    msg = f'[{tag}] Ep {epoch} loss {loss.item() / len(y):0.4f} '
    msg += f'pacc {y_pred[0]:0.4f}'
    msg += f'({y_pred[0] * 100.0 * acc_std + acc_mean:0.4f}) '
    msg += f'acc {y[0]:0.4f}({y[0] * 100 * acc_std + acc_mean:0.4f})'
    return msg


def load_model(model, model_path, load_epoch=None, load_max_pt=None):
    if load_max_pt is not None:
        ckpt_path = os.path.join(model_path, load_max_pt)
    else:
        ckpt_path = os.path.join(model_path, f'ckpt_{load_epoch}.pt')

    print(f"==> load checkpoint for MetaD2A predictor: {ckpt_path} ...")
    model.cpu()
    model.load_state_dict(torch.load(ckpt_path))


def save_model(epoch, model, model_path, max_corr=None):
    print("==> save current model...")
    if max_corr is not None:
        torch.save(model.cpu().state_dict(),
                   os.path.join(model_path, 'ckpt_max_corr.pt'))
    else:
        torch.save(model.cpu().state_dict(),
                   os.path.join(model_path, f'ckpt_{epoch}.pt'))


def mean_confidence_interval(data, confidence=0.95):
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
    return m, h
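
# Usage sketch (not part of the original commit): report accuracy over repeated runs
# as mean +/- half-width of a 95% confidence interval.
accs = [71.2, 70.8, 71.5, 70.9, 71.1]  # illustrative numbers
m, h = mean_confidence_interval(accs, confidence=0.95)
print(f'{m:.2f} +/- {h:.2f}')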
@@ -0,0 +1,5 @@
|
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .imagenet import *
@@ -0,0 +1,56 @@
|
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import numpy as np
import torch

__all__ = ['DataProvider']


class DataProvider:
    SUB_SEED = 937162211  # random seed for sampling subset
    VALID_SEED = 2147483647  # random seed for the validation set

    @staticmethod
    def name():
        """ Return name of the dataset """
        raise NotImplementedError

    @property
    def data_shape(self):
        """ Return shape as python list of one data entry """
        raise NotImplementedError

    @property
    def n_classes(self):
        """ Return `int` of num classes """
        raise NotImplementedError

    @property
    def save_path(self):
        """ local path to save the data """
        raise NotImplementedError

    @property
    def data_url(self):
        """ link to download the data """
        raise NotImplementedError

    @staticmethod
    def random_sample_valid_set(train_size, valid_size):
        assert train_size > valid_size

        g = torch.Generator()
        g.manual_seed(DataProvider.VALID_SEED)  # set random seed before sampling validation set
        rand_indexes = torch.randperm(train_size, generator=g).tolist()

        valid_indexes = rand_indexes[:valid_size]
        train_indexes = rand_indexes[valid_size:]
        return train_indexes, valid_indexes

    @staticmethod
    def labels_to_one_hot(n_classes, labels):
        new_labels = np.zeros((labels.shape[0], n_classes), dtype=np.float32)
        new_labels[range(labels.shape[0]), labels] = np.ones(labels.shape)
        return new_labels
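
# Usage sketch (not part of the original commit): the split is deterministic because
# it reseeds a dedicated generator, and labels_to_one_hot expects integer class indices.
train_idx, valid_idx = DataProvider.random_sample_valid_set(train_size=1000, valid_size=100)
assert len(train_idx) == 900 and len(valid_idx) == 100
one_hot = DataProvider.labels_to_one_hot(10, np.array([0, 3, 9]))
assert one_hot.shape == (3, 10) and one_hot[1, 3] == 1.0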
@@ -0,0 +1,225 @@
|
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import warnings
import os
import math
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from .base_provider import DataProvider
from ofa_local.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler

__all__ = ['ImagenetDataProvider']


class ImagenetDataProvider(DataProvider):
    DEFAULT_PATH = '/dataset/imagenet'

    def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):

        warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = 'None' if distort_color is None else distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            # import from ofa_local for consistency with the other imports in this file
            # (the original line pulled MyDataLoader from the external `ofa` package)
            from ofa_local.utils.my_dataloader import MyDataLoader
            assert isinstance(self.image_size, list)
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)  # active resolution for test
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_dataset = self.train_dataset(self.build_train_transform())

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, True, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, True, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'imagenet'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 1000

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = self.DEFAULT_PATH
            if not os.path.exists(self._save_path):
                self._save_path = os.path.expanduser('~/dataset/imagenet')
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        return datasets.ImageFolder(self.train_path, _transforms)

    def test_dataset(self, _transforms):
        return datasets.ImageFolder(self.valid_path, _transforms)

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'val')

    @property
    def normalize(self):
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        # random_resize_crop -> random_horizontal_flip
        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]

        # color augmentation (optional)
        color_transform = None
        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        if color_transform is not None:
            train_transforms.append(color_transform)

        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting BN running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, True, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
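
# Usage sketch (not part of the original commit): the path below is an assumption and
# must contain ImageNet-style train/ and val/ folders.
#
# provider = ImagenetDataProvider(save_path='/dataset/imagenet',
#                                 train_batch_size=256, test_batch_size=512,
#                                 valid_size=10000, image_size=224)
# provider.assign_active_img_size(192)  # re-crops valid/test at the new resolution
# sub_loader = provider.build_sub_train_loader(n_images=2000, batch_size=200)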
@@ -0,0 +1,6 @@
|
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .dynamic_layers import *
from .dynamic_op import *
@@ -0,0 +1,632 @@
|
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import torch
import torch.nn as nn
from collections import OrderedDict

from ofa_local.utils.layers import MBConvLayer, ConvLayer, IdentityLayer, set_layer_from_config
from ofa_local.utils.layers import ResNetBottleneckBlock, LinearLayer
from ofa_local.utils import MyModule, val2list, get_net_device, build_activation, make_divisible, SEModule, MyNetwork
from .dynamic_op import DynamicSeparableConv2d, DynamicConv2d, DynamicBatchNorm2d, DynamicSE, DynamicGroupNorm
from .dynamic_op import DynamicLinear

__all__ = [
    'adjust_bn_according_to_idx', 'copy_bn',
    'DynamicMBConvLayer', 'DynamicConvLayer', 'DynamicLinearLayer', 'DynamicResNetBottleneckBlock'
]


def adjust_bn_according_to_idx(bn, idx):
    bn.weight.data = torch.index_select(bn.weight.data, 0, idx)
    bn.bias.data = torch.index_select(bn.bias.data, 0, idx)
    if type(bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
        bn.running_mean.data = torch.index_select(bn.running_mean.data, 0, idx)
        bn.running_var.data = torch.index_select(bn.running_var.data, 0, idx)


def copy_bn(target_bn, src_bn):
    feature_dim = target_bn.num_channels if isinstance(target_bn, nn.GroupNorm) else target_bn.num_features

    target_bn.weight.data.copy_(src_bn.weight.data[:feature_dim])
    target_bn.bias.data.copy_(src_bn.bias.data[:feature_dim])
    if type(src_bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
        target_bn.running_mean.data.copy_(src_bn.running_mean.data[:feature_dim])
        target_bn.running_var.data.copy_(src_bn.running_var.data[:feature_dim])


class DynamicLinearLayer(MyModule):

    def __init__(self, in_features_list, out_features, bias=True, dropout_rate=0):
        super(DynamicLinearLayer, self).__init__()

        self.in_features_list = in_features_list
        self.out_features = out_features
        self.bias = bias
        self.dropout_rate = dropout_rate

        if self.dropout_rate > 0:
            self.dropout = nn.Dropout(self.dropout_rate, inplace=True)
        else:
            self.dropout = None
        self.linear = DynamicLinear(
            max_in_features=max(self.in_features_list), max_out_features=self.out_features, bias=self.bias
        )

    def forward(self, x):
        if self.dropout is not None:
            x = self.dropout(x)
        return self.linear(x)

    @property
    def module_str(self):
        return 'DyLinear(%d, %d)' % (max(self.in_features_list), self.out_features)

    @property
    def config(self):
        return {
            'name': DynamicLinear.__name__,
            'in_features_list': self.in_features_list,
            'out_features': self.out_features,
            'bias': self.bias,
            'dropout_rate': self.dropout_rate,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicLinearLayer(**config)

    def get_active_subnet(self, in_features, preserve_weight=True):
        sub_layer = LinearLayer(in_features, self.out_features, self.bias, dropout_rate=self.dropout_rate)
        sub_layer = sub_layer.to(get_net_device(self))
        if not preserve_weight:
            return sub_layer

        sub_layer.linear.weight.data.copy_(
            self.linear.get_active_weight(self.out_features, in_features).data
        )
        if self.bias:
            sub_layer.linear.bias.data.copy_(
                self.linear.get_active_bias(self.out_features).data
            )
        return sub_layer

    def get_active_subnet_config(self, in_features):
        return {
            'name': LinearLayer.__name__,
            'in_features': in_features,
            'out_features': self.out_features,
            'bias': self.bias,
            'dropout_rate': self.dropout_rate,
        }


class DynamicMBConvLayer(MyModule):

    def __init__(self, in_channel_list, out_channel_list,
                 kernel_size_list=3, expand_ratio_list=6, stride=1, act_func='relu6', use_se=False):
        super(DynamicMBConvLayer, self).__init__()

        self.in_channel_list = in_channel_list
        self.out_channel_list = out_channel_list

        self.kernel_size_list = val2list(kernel_size_list)
        self.expand_ratio_list = val2list(expand_ratio_list)

        self.stride = stride
        self.act_func = act_func
        self.use_se = use_se

        # build modules
        max_middle_channel = make_divisible(
            round(max(self.in_channel_list) * max(self.expand_ratio_list)), MyNetwork.CHANNEL_DIVISIBLE)
        if max(self.expand_ratio_list) == 1:
            self.inverted_bottleneck = None
        else:
            self.inverted_bottleneck = nn.Sequential(OrderedDict([
                ('conv', DynamicConv2d(max(self.in_channel_list), max_middle_channel)),
                ('bn', DynamicBatchNorm2d(max_middle_channel)),
                ('act', build_activation(self.act_func)),
            ]))

        self.depth_conv = nn.Sequential(OrderedDict([
            ('conv', DynamicSeparableConv2d(max_middle_channel, self.kernel_size_list, self.stride)),
            ('bn', DynamicBatchNorm2d(max_middle_channel)),
            ('act', build_activation(self.act_func))
        ]))
        if self.use_se:
            self.depth_conv.add_module('se', DynamicSE(max_middle_channel))

        self.point_linear = nn.Sequential(OrderedDict([
            ('conv', DynamicConv2d(max_middle_channel, max(self.out_channel_list))),
            ('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
        ]))

        self.active_kernel_size = max(self.kernel_size_list)
        self.active_expand_ratio = max(self.expand_ratio_list)
        self.active_out_channel = max(self.out_channel_list)

    def forward(self, x):
        in_channel = x.size(1)

        if self.inverted_bottleneck is not None:
            self.inverted_bottleneck.conv.active_out_channel = \
                make_divisible(round(in_channel * self.active_expand_ratio), MyNetwork.CHANNEL_DIVISIBLE)

        self.depth_conv.conv.active_kernel_size = self.active_kernel_size
        self.point_linear.conv.active_out_channel = self.active_out_channel

        if self.inverted_bottleneck is not None:
            x = self.inverted_bottleneck(x)
        x = self.depth_conv(x)
        x = self.point_linear(x)
        return x

    @property
    def module_str(self):
        if self.use_se:
            return 'SE(O%d, E%.1f, K%d)' % (self.active_out_channel, self.active_expand_ratio, self.active_kernel_size)
        else:
            return '(O%d, E%.1f, K%d)' % (self.active_out_channel, self.active_expand_ratio, self.active_kernel_size)

    @property
    def config(self):
        return {
            'name': DynamicMBConvLayer.__name__,
            'in_channel_list': self.in_channel_list,
            'out_channel_list': self.out_channel_list,
            'kernel_size_list': self.kernel_size_list,
            'expand_ratio_list': self.expand_ratio_list,
            'stride': self.stride,
            'act_func': self.act_func,
            'use_se': self.use_se,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicMBConvLayer(**config)

    ############################################################################################

    @property
    def in_channels(self):
        return max(self.in_channel_list)

    @property
    def out_channels(self):
        return max(self.out_channel_list)

    def active_middle_channel(self, in_channel):
        return make_divisible(round(in_channel * self.active_expand_ratio), MyNetwork.CHANNEL_DIVISIBLE)

    ############################################################################################

    def get_active_subnet(self, in_channel, preserve_weight=True):
        # build the new layer
        sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
        sub_layer = sub_layer.to(get_net_device(self))
        if not preserve_weight:
            return sub_layer

        middle_channel = self.active_middle_channel(in_channel)
        # copy weight from current layer
        if sub_layer.inverted_bottleneck is not None:
            sub_layer.inverted_bottleneck.conv.weight.data.copy_(
                self.inverted_bottleneck.conv.get_active_filter(middle_channel, in_channel).data,
            )
            copy_bn(sub_layer.inverted_bottleneck.bn, self.inverted_bottleneck.bn.bn)

        sub_layer.depth_conv.conv.weight.data.copy_(
            self.depth_conv.conv.get_active_filter(middle_channel, self.active_kernel_size).data
        )
        copy_bn(sub_layer.depth_conv.bn, self.depth_conv.bn.bn)

        if self.use_se:
            se_mid = make_divisible(middle_channel // SEModule.REDUCTION, divisor=MyNetwork.CHANNEL_DIVISIBLE)
            sub_layer.depth_conv.se.fc.reduce.weight.data.copy_(
                self.depth_conv.se.get_active_reduce_weight(se_mid, middle_channel).data
            )
            sub_layer.depth_conv.se.fc.reduce.bias.data.copy_(
                self.depth_conv.se.get_active_reduce_bias(se_mid).data
            )

            sub_layer.depth_conv.se.fc.expand.weight.data.copy_(
                self.depth_conv.se.get_active_expand_weight(se_mid, middle_channel).data
            )
            sub_layer.depth_conv.se.fc.expand.bias.data.copy_(
                self.depth_conv.se.get_active_expand_bias(middle_channel).data
            )

        sub_layer.point_linear.conv.weight.data.copy_(
            self.point_linear.conv.get_active_filter(self.active_out_channel, middle_channel).data
        )
        copy_bn(sub_layer.point_linear.bn, self.point_linear.bn.bn)

        return sub_layer

    def get_active_subnet_config(self, in_channel):
        return {
            'name': MBConvLayer.__name__,
            'in_channels': in_channel,
            'out_channels': self.active_out_channel,
            'kernel_size': self.active_kernel_size,
            'stride': self.stride,
            'expand_ratio': self.active_expand_ratio,
            'mid_channels': self.active_middle_channel(in_channel),
            'act_func': self.act_func,
            'use_se': self.use_se,
        }

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        importance = torch.sum(torch.abs(self.point_linear.conv.conv.weight.data), dim=(0, 2, 3))
        if isinstance(self.depth_conv.bn, DynamicGroupNorm):
            channel_per_group = self.depth_conv.bn.channel_per_group
            importance_chunks = torch.split(importance, channel_per_group)
            for chunk in importance_chunks:
                chunk.data.fill_(torch.mean(chunk))
            importance = torch.cat(importance_chunks, dim=0)
        if expand_ratio_stage > 0:
            sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
            sorted_expand_list.sort(reverse=True)
            target_width_list = [
                make_divisible(round(max(self.in_channel_list) * expand), MyNetwork.CHANNEL_DIVISIBLE)
                for expand in sorted_expand_list
            ]

            right = len(importance)
            base = - len(target_width_list) * 1e5
            for i in range(expand_ratio_stage + 1):
                left = target_width_list[i]
                importance[left:right] += base
                base += 1e5
                right = left

        sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
        self.point_linear.conv.conv.weight.data = torch.index_select(
            self.point_linear.conv.conv.weight.data, 1, sorted_idx
        )

        adjust_bn_according_to_idx(self.depth_conv.bn.bn, sorted_idx)
        self.depth_conv.conv.conv.weight.data = torch.index_select(
            self.depth_conv.conv.conv.weight.data, 0, sorted_idx
        )

        if self.use_se:
            # se expand: output dim 0 reorganize
            se_expand = self.depth_conv.se.fc.expand
            se_expand.weight.data = torch.index_select(se_expand.weight.data, 0, sorted_idx)
            se_expand.bias.data = torch.index_select(se_expand.bias.data, 0, sorted_idx)
            # se reduce: input dim 1 reorganize
            se_reduce = self.depth_conv.se.fc.reduce
            se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 1, sorted_idx)
            # middle weight reorganize
            se_importance = torch.sum(torch.abs(se_expand.weight.data), dim=(0, 2, 3))
            se_importance, se_idx = torch.sort(se_importance, dim=0, descending=True)

            se_expand.weight.data = torch.index_select(se_expand.weight.data, 1, se_idx)
            se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 0, se_idx)
            se_reduce.bias.data = torch.index_select(se_reduce.bias.data, 0, se_idx)

        if self.inverted_bottleneck is not None:
            adjust_bn_according_to_idx(self.inverted_bottleneck.bn.bn, sorted_idx)
            self.inverted_bottleneck.conv.conv.weight.data = torch.index_select(
                self.inverted_bottleneck.conv.conv.weight.data, 0, sorted_idx
            )
            return None
        else:
            return sorted_idx
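
# Usage sketch (not part of the original commit): pick an active configuration and
# extract a static MBConv sub-layer; the channel lists below are illustrative.
# layer = DynamicMBConvLayer(in_channel_list=[24], out_channel_list=[32],
#                            kernel_size_list=[3, 5, 7], expand_ratio_list=[3, 4, 6],
#                            stride=1, act_func='relu6', use_se=True)
# layer.active_kernel_size, layer.active_expand_ratio = 5, 4
# sub = layer.get_active_subnet(in_channel=24)  # static MBConvLayer with copied weights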


class DynamicConvLayer(MyModule):

    def __init__(self, in_channel_list, out_channel_list, kernel_size=3, stride=1, dilation=1,
                 use_bn=True, act_func='relu6'):
        super(DynamicConvLayer, self).__init__()

        self.in_channel_list = in_channel_list
        self.out_channel_list = out_channel_list
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.use_bn = use_bn
        self.act_func = act_func

        self.conv = DynamicConv2d(
            max_in_channels=max(self.in_channel_list), max_out_channels=max(self.out_channel_list),
            kernel_size=self.kernel_size, stride=self.stride, dilation=self.dilation,
        )
        if self.use_bn:
            self.bn = DynamicBatchNorm2d(max(self.out_channel_list))
        self.act = build_activation(self.act_func)

        self.active_out_channel = max(self.out_channel_list)

    def forward(self, x):
        self.conv.active_out_channel = self.active_out_channel

        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        x = self.act(x)
        return x

    @property
    def module_str(self):
        return 'DyConv(O%d, K%d, S%d)' % (self.active_out_channel, self.kernel_size, self.stride)

    @property
    def config(self):
        return {
            'name': DynamicConvLayer.__name__,
            'in_channel_list': self.in_channel_list,
            'out_channel_list': self.out_channel_list,
            'kernel_size': self.kernel_size,
            'stride': self.stride,
            'dilation': self.dilation,
            'use_bn': self.use_bn,
            'act_func': self.act_func,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicConvLayer(**config)

    ############################################################################################

    @property
    def in_channels(self):
        return max(self.in_channel_list)

    @property
    def out_channels(self):
        return max(self.out_channel_list)

    ############################################################################################

    def get_active_subnet(self, in_channel, preserve_weight=True):
        sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
        sub_layer = sub_layer.to(get_net_device(self))

        if not preserve_weight:
            return sub_layer

        sub_layer.conv.weight.data.copy_(self.conv.get_active_filter(self.active_out_channel, in_channel).data)
        if self.use_bn:
            copy_bn(sub_layer.bn, self.bn.bn)

        return sub_layer

    def get_active_subnet_config(self, in_channel):
        return {
            'name': ConvLayer.__name__,
            'in_channels': in_channel,
            'out_channels': self.active_out_channel,
            'kernel_size': self.kernel_size,
            'stride': self.stride,
            'dilation': self.dilation,
            'use_bn': self.use_bn,
            'act_func': self.act_func,
        }


class DynamicResNetBottleneckBlock(MyModule):

    def __init__(self, in_channel_list, out_channel_list, expand_ratio_list=0.25,
                 kernel_size=3, stride=1, act_func='relu', downsample_mode='avgpool_conv'):
        super(DynamicResNetBottleneckBlock, self).__init__()

        self.in_channel_list = in_channel_list
        self.out_channel_list = out_channel_list
        self.expand_ratio_list = val2list(expand_ratio_list)

        self.kernel_size = kernel_size
        self.stride = stride
        self.act_func = act_func
        self.downsample_mode = downsample_mode

        # build modules
        max_middle_channel = make_divisible(
            round(max(self.out_channel_list) * max(self.expand_ratio_list)), MyNetwork.CHANNEL_DIVISIBLE)

        self.conv1 = nn.Sequential(OrderedDict([
            ('conv', DynamicConv2d(max(self.in_channel_list), max_middle_channel)),
            ('bn', DynamicBatchNorm2d(max_middle_channel)),
            ('act', build_activation(self.act_func, inplace=True)),
        ]))

        self.conv2 = nn.Sequential(OrderedDict([
            ('conv', DynamicConv2d(max_middle_channel, max_middle_channel, kernel_size, stride)),
            ('bn', DynamicBatchNorm2d(max_middle_channel)),
            ('act', build_activation(self.act_func, inplace=True))
        ]))

        self.conv3 = nn.Sequential(OrderedDict([
            ('conv', DynamicConv2d(max_middle_channel, max(self.out_channel_list))),
            ('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
        ]))

        if self.stride == 1 and self.in_channel_list == self.out_channel_list:
            self.downsample = IdentityLayer(max(self.in_channel_list), max(self.out_channel_list))
        elif self.downsample_mode == 'conv':
            self.downsample = nn.Sequential(OrderedDict([
                ('conv', DynamicConv2d(max(self.in_channel_list), max(self.out_channel_list), stride=stride)),
                ('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
            ]))
        elif self.downsample_mode == 'avgpool_conv':
            self.downsample = nn.Sequential(OrderedDict([
                ('avg_pool', nn.AvgPool2d(kernel_size=stride, stride=stride, padding=0, ceil_mode=True)),
                ('conv', DynamicConv2d(max(self.in_channel_list), max(self.out_channel_list))),
                ('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
            ]))
        else:
            raise NotImplementedError

        self.final_act = build_activation(self.act_func, inplace=True)

        self.active_expand_ratio = max(self.expand_ratio_list)
        self.active_out_channel = max(self.out_channel_list)

    def forward(self, x):
        feature_dim = self.active_middle_channels

        self.conv1.conv.active_out_channel = feature_dim
        self.conv2.conv.active_out_channel = feature_dim
        self.conv3.conv.active_out_channel = self.active_out_channel
        if not isinstance(self.downsample, IdentityLayer):
            self.downsample.conv.active_out_channel = self.active_out_channel

        residual = self.downsample(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        x = x + residual
        x = self.final_act(x)
        return x

    @property
    def module_str(self):
        return '(%s, %s)' % (
            '%dx%d_BottleneckConv_in->%d->%d_S%d' % (
                self.kernel_size, self.kernel_size, self.active_middle_channels, self.active_out_channel, self.stride
            ),
            'Identity' if isinstance(self.downsample, IdentityLayer) else self.downsample_mode,
        )

    @property
    def config(self):
        return {
            'name': DynamicResNetBottleneckBlock.__name__,
            'in_channel_list': self.in_channel_list,
            'out_channel_list': self.out_channel_list,
            'expand_ratio_list': self.expand_ratio_list,
            'kernel_size': self.kernel_size,
            'stride': self.stride,
            'act_func': self.act_func,
            'downsample_mode': self.downsample_mode,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicResNetBottleneckBlock(**config)

    ############################################################################################

    @property
    def in_channels(self):
        return max(self.in_channel_list)

    @property
    def out_channels(self):
        return max(self.out_channel_list)

    @property
    def active_middle_channels(self):
        feature_dim = round(self.active_out_channel * self.active_expand_ratio)
        feature_dim = make_divisible(feature_dim, MyNetwork.CHANNEL_DIVISIBLE)
        return feature_dim

    ############################################################################################

    def get_active_subnet(self, in_channel, preserve_weight=True):
        # build the new layer
        sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
        sub_layer = sub_layer.to(get_net_device(self))
        if not preserve_weight:
            return sub_layer

        # copy weight from current layer
        sub_layer.conv1.conv.weight.data.copy_(
            self.conv1.conv.get_active_filter(self.active_middle_channels, in_channel).data)
        copy_bn(sub_layer.conv1.bn, self.conv1.bn.bn)

        sub_layer.conv2.conv.weight.data.copy_(
            self.conv2.conv.get_active_filter(self.active_middle_channels, self.active_middle_channels).data)
        copy_bn(sub_layer.conv2.bn, self.conv2.bn.bn)

        sub_layer.conv3.conv.weight.data.copy_(
            self.conv3.conv.get_active_filter(self.active_out_channel, self.active_middle_channels).data)
        copy_bn(sub_layer.conv3.bn, self.conv3.bn.bn)

        if not isinstance(self.downsample, IdentityLayer):
            sub_layer.downsample.conv.weight.data.copy_(
                self.downsample.conv.get_active_filter(self.active_out_channel, in_channel).data)
            copy_bn(sub_layer.downsample.bn, self.downsample.bn.bn)

        return sub_layer

    def get_active_subnet_config(self, in_channel):
        return {
            'name': ResNetBottleneckBlock.__name__,
            'in_channels': in_channel,
            'out_channels': self.active_out_channel,
            'kernel_size': self.kernel_size,
            'stride': self.stride,
            'expand_ratio': self.active_expand_ratio,
            'mid_channels': self.active_middle_channels,
            'act_func': self.act_func,
            'groups': 1,
            'downsample_mode': self.downsample_mode,
        }

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        # conv3 -> conv2
        importance = torch.sum(torch.abs(self.conv3.conv.conv.weight.data), dim=(0, 2, 3))
        if isinstance(self.conv2.bn, DynamicGroupNorm):
            channel_per_group = self.conv2.bn.channel_per_group
            importance_chunks = torch.split(importance, channel_per_group)
            for chunk in importance_chunks:
                chunk.data.fill_(torch.mean(chunk))
            importance = torch.cat(importance_chunks, dim=0)
        if expand_ratio_stage > 0:
            sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
            sorted_expand_list.sort(reverse=True)
            target_width_list = [
                make_divisible(round(max(self.out_channel_list) * expand), MyNetwork.CHANNEL_DIVISIBLE)
                for expand in sorted_expand_list
            ]
            right = len(importance)
            base = - len(target_width_list) * 1e5
            for i in range(expand_ratio_stage + 1):
                left = target_width_list[i]
                importance[left:right] += base
                base += 1e5
                right = left

        sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
        self.conv3.conv.conv.weight.data = torch.index_select(self.conv3.conv.conv.weight.data, 1, sorted_idx)
        adjust_bn_according_to_idx(self.conv2.bn.bn, sorted_idx)
        self.conv2.conv.conv.weight.data = torch.index_select(self.conv2.conv.conv.weight.data, 0, sorted_idx)

        # conv2 -> conv1
        importance = torch.sum(torch.abs(self.conv2.conv.conv.weight.data), dim=(0, 2, 3))
        if isinstance(self.conv1.bn, DynamicGroupNorm):
            channel_per_group = self.conv1.bn.channel_per_group
            importance_chunks = torch.split(importance, channel_per_group)
            for chunk in importance_chunks:
                chunk.data.fill_(torch.mean(chunk))
            importance = torch.cat(importance_chunks, dim=0)
        if expand_ratio_stage > 0:
            sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
            sorted_expand_list.sort(reverse=True)
            target_width_list = [
                make_divisible(round(max(self.out_channel_list) * expand), MyNetwork.CHANNEL_DIVISIBLE)
                for expand in sorted_expand_list
            ]
            right = len(importance)
            base = - len(target_width_list) * 1e5
            for i in range(expand_ratio_stage + 1):
                left = target_width_list[i]
                importance[left:right] += base
                base += 1e5
                right = left
        sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)

        self.conv2.conv.conv.weight.data = torch.index_select(self.conv2.conv.conv.weight.data, 1, sorted_idx)
        adjust_bn_according_to_idx(self.conv1.bn.bn, sorted_idx)
        self.conv1.conv.conv.weight.data = torch.index_select(self.conv1.conv.conv.weight.data, 0, sorted_idx)

        return None
@@ -0,0 +1,314 @@
|
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import torch.nn.functional as F
import torch.nn as nn
import torch
from torch.nn.parameter import Parameter

from ofa_local.utils import get_same_padding, sub_filter_start_end, make_divisible, SEModule, MyNetwork, MyConv2d

__all__ = ['DynamicSeparableConv2d', 'DynamicConv2d', 'DynamicGroupConv2d',
           'DynamicBatchNorm2d', 'DynamicGroupNorm', 'DynamicSE', 'DynamicLinear']


class DynamicSeparableConv2d(nn.Module):
    KERNEL_TRANSFORM_MODE = 1  # None or 1

    def __init__(self, max_in_channels, kernel_size_list, stride=1, dilation=1):
        super(DynamicSeparableConv2d, self).__init__()

        self.max_in_channels = max_in_channels
        self.kernel_size_list = kernel_size_list
        self.stride = stride
        self.dilation = dilation

        self.conv = nn.Conv2d(
            self.max_in_channels, self.max_in_channels, max(self.kernel_size_list), self.stride,
            groups=self.max_in_channels, bias=False,
        )

        self._ks_set = list(set(self.kernel_size_list))
        self._ks_set.sort()  # e.g., [3, 5, 7]
        if self.KERNEL_TRANSFORM_MODE is not None:
            # register scaling parameters
            # 7to5_matrix, 5to3_matrix
            scale_params = {}
            for i in range(len(self._ks_set) - 1):
                ks_small = self._ks_set[i]
                ks_larger = self._ks_set[i + 1]
                param_name = '%dto%d' % (ks_larger, ks_small)
                # noinspection PyArgumentList
                scale_params['%s_matrix' % param_name] = Parameter(torch.eye(ks_small ** 2))
            for name, param in scale_params.items():
                self.register_parameter(name, param)

        self.active_kernel_size = max(self.kernel_size_list)

    def get_active_filter(self, in_channel, kernel_size):
        out_channel = in_channel
        max_kernel_size = max(self.kernel_size_list)

        start, end = sub_filter_start_end(max_kernel_size, kernel_size)
        filters = self.conv.weight[:out_channel, :in_channel, start:end, start:end]
        if self.KERNEL_TRANSFORM_MODE is not None and kernel_size < max_kernel_size:
            start_filter = self.conv.weight[:out_channel, :in_channel, :, :]  # start with max kernel
            for i in range(len(self._ks_set) - 1, 0, -1):
                src_ks = self._ks_set[i]
                if src_ks <= kernel_size:
                    break
                target_ks = self._ks_set[i - 1]
                start, end = sub_filter_start_end(src_ks, target_ks)
                _input_filter = start_filter[:, :, start:end, start:end]
                _input_filter = _input_filter.contiguous()
                _input_filter = _input_filter.view(_input_filter.size(0), _input_filter.size(1), -1)
                _input_filter = _input_filter.view(-1, _input_filter.size(2))
                _input_filter = F.linear(
                    _input_filter, self.__getattr__('%dto%d_matrix' % (src_ks, target_ks)),
                )
                _input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks ** 2)
                _input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks, target_ks)
                start_filter = _input_filter
            filters = start_filter
        return filters

    def forward(self, x, kernel_size=None):
        if kernel_size is None:
            kernel_size = self.active_kernel_size
        in_channel = x.size(1)

        filters = self.get_active_filter(in_channel, kernel_size).contiguous()

        padding = get_same_padding(kernel_size)
        filters = self.conv.weight_standardization(filters) if isinstance(self.conv, MyConv2d) else filters
        y = F.conv2d(
            x, filters, None, self.stride, padding, self.dilation, in_channel
        )
        return y
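
# Usage sketch (not part of the original commit): with KERNEL_TRANSFORM_MODE enabled,
# a 7x7 depthwise filter is center-cropped to 5x5 (then 3x3) and passed through the
# learned %dto%d transformation matrices, so one weight tensor serves every kernel size.
# conv = DynamicSeparableConv2d(max_in_channels=40, kernel_size_list=[3, 5, 7])
# conv.active_kernel_size = 5
# y = conv(torch.randn(1, 40, 32, 32))  # runs the 5x5 sub-filter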


class DynamicConv2d(nn.Module):

    def __init__(self, max_in_channels, max_out_channels, kernel_size=1, stride=1, dilation=1):
        super(DynamicConv2d, self).__init__()

        self.max_in_channels = max_in_channels
        self.max_out_channels = max_out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation

        self.conv = nn.Conv2d(
            self.max_in_channels, self.max_out_channels, self.kernel_size, stride=self.stride, bias=False,
        )

        self.active_out_channel = self.max_out_channels

    def get_active_filter(self, out_channel, in_channel):
        return self.conv.weight[:out_channel, :in_channel, :, :]

    def forward(self, x, out_channel=None):
        if out_channel is None:
            out_channel = self.active_out_channel
        in_channel = x.size(1)
        filters = self.get_active_filter(out_channel, in_channel).contiguous()

        padding = get_same_padding(self.kernel_size)
        filters = self.conv.weight_standardization(filters) if isinstance(self.conv, MyConv2d) else filters
        y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, 1)
        return y


class DynamicGroupConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size_list, groups_list, stride=1, dilation=1):
        super(DynamicGroupConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size_list = kernel_size_list
        self.groups_list = groups_list
        self.stride = stride
        self.dilation = dilation

        self.conv = nn.Conv2d(
            self.in_channels, self.out_channels, max(self.kernel_size_list), self.stride,
            groups=min(self.groups_list), bias=False,
        )

        self.active_kernel_size = max(self.kernel_size_list)
        self.active_groups = min(self.groups_list)

    def get_active_filter(self, kernel_size, groups):
        start, end = sub_filter_start_end(max(self.kernel_size_list), kernel_size)
        filters = self.conv.weight[:, :, start:end, start:end]

        sub_filters = torch.chunk(filters, groups, dim=0)
        sub_in_channels = self.in_channels // groups
        sub_ratio = filters.size(1) // sub_in_channels

        filter_crops = []
        for i, sub_filter in enumerate(sub_filters):
            part_id = i % sub_ratio
            start = part_id * sub_in_channels
            filter_crops.append(sub_filter[:, start:start + sub_in_channels, :, :])
        filters = torch.cat(filter_crops, dim=0)
        return filters

    def forward(self, x, kernel_size=None, groups=None):
        if kernel_size is None:
            kernel_size = self.active_kernel_size
        if groups is None:
            groups = self.active_groups

        filters = self.get_active_filter(kernel_size, groups).contiguous()
        padding = get_same_padding(kernel_size)
        filters = self.conv.weight_standardization(filters) if isinstance(self.conv, MyConv2d) else filters
        y = F.conv2d(
            x, filters, None, self.stride, padding, self.dilation, groups,
        )
        return y
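
# Hedged walk-through of get_active_filter above (numbers are assumptions):
# the stored weight was allocated for min(groups_list) groups, so its input
# dimension is wider than an active group needs. With in_channels=8 and an
# active groups=2, sub_in_channels=4; each chunk along dim 0 keeps only the
# sub_in_channels slice selected by part_id before the crops are concatenated
# back into a weight that is valid for F.conv2d(..., groups=2).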


class DynamicBatchNorm2d(nn.Module):
    SET_RUNNING_STATISTICS = False

    def __init__(self, max_feature_dim):
        super(DynamicBatchNorm2d, self).__init__()

        self.max_feature_dim = max_feature_dim
        self.bn = nn.BatchNorm2d(self.max_feature_dim)

    @staticmethod
    def bn_forward(x, bn: nn.BatchNorm2d, feature_dim):
        if bn.num_features == feature_dim or DynamicBatchNorm2d.SET_RUNNING_STATISTICS:
            return bn(x)
        else:
            exponential_average_factor = 0.0

            if bn.training and bn.track_running_stats:
                if bn.num_batches_tracked is not None:
                    bn.num_batches_tracked += 1
                    if bn.momentum is None:  # use cumulative moving average
                        exponential_average_factor = 1.0 / float(bn.num_batches_tracked)
                    else:  # use exponential moving average
                        exponential_average_factor = bn.momentum
            return F.batch_norm(
                x, bn.running_mean[:feature_dim], bn.running_var[:feature_dim], bn.weight[:feature_dim],
                bn.bias[:feature_dim], bn.training or not bn.track_running_stats,
                exponential_average_factor, bn.eps,
            )

    def forward(self, x):
        feature_dim = x.size(1)
        y = self.bn_forward(x, self.bn, feature_dim)
        return y
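
# Hedged sketch (illustrative, not from the original source): when a subnet
# uses fewer channels than max_feature_dim, bn_forward slices the first
# feature_dim entries of the running statistics and affine parameters, e.g.
#
#   bn = DynamicBatchNorm2d(max_feature_dim=64)
#   y = bn(torch.randn(2, 48, 14, 14))  # normalizes with bn.bn.weight[:48], etc.
#
# SET_RUNNING_STATISTICS is switched on while BN statistics are re-calibrated
# for a sampled subnet, so that pass goes through the plain nn.BatchNorm2d path.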


class DynamicGroupNorm(nn.GroupNorm):

    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, channel_per_group=None):
        super(DynamicGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
        self.channel_per_group = channel_per_group

    def forward(self, x):
        n_channels = x.size(1)
        n_groups = n_channels // self.channel_per_group
        return F.group_norm(x, n_groups, self.weight[:n_channels], self.bias[:n_channels], self.eps)

    @property
    def bn(self):
        return self


class DynamicSE(SEModule):

    def __init__(self, max_channel):
        super(DynamicSE, self).__init__(max_channel)

    def get_active_reduce_weight(self, num_mid, in_channel, groups=None):
        if groups is None or groups == 1:
            return self.fc.reduce.weight[:num_mid, :in_channel, :, :]
        else:
            assert in_channel % groups == 0
            sub_in_channels = in_channel // groups
            sub_filters = torch.chunk(self.fc.reduce.weight[:num_mid, :, :, :], groups, dim=1)
            return torch.cat([
                sub_filter[:, :sub_in_channels, :, :] for sub_filter in sub_filters
            ], dim=1)

    def get_active_reduce_bias(self, num_mid):
        return self.fc.reduce.bias[:num_mid] if self.fc.reduce.bias is not None else None

    def get_active_expand_weight(self, num_mid, in_channel, groups=None):
        if groups is None or groups == 1:
            return self.fc.expand.weight[:in_channel, :num_mid, :, :]
        else:
            assert in_channel % groups == 0
            sub_in_channels = in_channel // groups
            sub_filters = torch.chunk(self.fc.expand.weight[:, :num_mid, :, :], groups, dim=0)
            return torch.cat([
                sub_filter[:sub_in_channels, :, :, :] for sub_filter in sub_filters
            ], dim=0)

    def get_active_expand_bias(self, in_channel, groups=None):
        if groups is None or groups == 1:
            return self.fc.expand.bias[:in_channel] if self.fc.expand.bias is not None else None
        else:
            assert in_channel % groups == 0
            sub_in_channels = in_channel // groups
            sub_bias_list = torch.chunk(self.fc.expand.bias, groups, dim=0)
            return torch.cat([
                sub_bias[:sub_in_channels] for sub_bias in sub_bias_list
            ], dim=0)

    def forward(self, x, groups=None):
        in_channel = x.size(1)
        num_mid = make_divisible(in_channel // self.reduction, divisor=MyNetwork.CHANNEL_DIVISIBLE)

        y = x.mean(3, keepdim=True).mean(2, keepdim=True)
        # reduce
        reduce_filter = self.get_active_reduce_weight(num_mid, in_channel, groups=groups).contiguous()
        reduce_bias = self.get_active_reduce_bias(num_mid)
        y = F.conv2d(y, reduce_filter, reduce_bias, 1, 0, 1, 1)
        # relu
        y = self.fc.relu(y)
        # expand
        expand_filter = self.get_active_expand_weight(num_mid, in_channel, groups=groups).contiguous()
        expand_bias = self.get_active_expand_bias(in_channel, groups=groups)
        y = F.conv2d(y, expand_filter, expand_bias, 1, 0, 1, 1)
        # hard sigmoid
        y = self.fc.h_sigmoid(y)

        return x * y
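
# Hedged walk-through (illustrative only): for in_channel=40 and the usual
# reduction=4, num_mid = make_divisible(40 // 4) = 8, so the SE path computes
#
#   scale = h_sigmoid(W_expand @ relu(W_reduce @ global_avg_pool(x)))
#
# with W_reduce sliced to [8, 40, 1, 1] and W_expand to [40, 8, 1, 1] out of
# the weights stored for max_channel; the input is then rescaled channel-wise.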


class DynamicLinear(nn.Module):

    def __init__(self, max_in_features, max_out_features, bias=True):
        super(DynamicLinear, self).__init__()

        self.max_in_features = max_in_features
        self.max_out_features = max_out_features
        self.bias = bias

        self.linear = nn.Linear(self.max_in_features, self.max_out_features, self.bias)

        self.active_out_features = self.max_out_features

    def get_active_weight(self, out_features, in_features):
        return self.linear.weight[:out_features, :in_features]

    def get_active_bias(self, out_features):
        return self.linear.bias[:out_features] if self.bias else None

    def forward(self, x, out_features=None):
        if out_features is None:
            out_features = self.active_out_features

        in_features = x.size(1)
        weight = self.get_active_weight(out_features, in_features).contiguous()
        bias = self.get_active_bias(out_features)
        y = F.linear(x, weight, bias)
        return y
@@ -0,0 +1,7 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .ofa_proxyless import OFAProxylessNASNets
from .ofa_mbv3 import OFAMobileNetV3
from .ofa_resnets import OFAResNets
@@ -0,0 +1,336 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import random

from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_layers import DynamicMBConvLayer
from ofa_local.utils.layers import ConvLayer, IdentityLayer, LinearLayer, MBConvLayer, ResidualBlock
from ofa_local.imagenet_classification.networks import MobileNetV3
from ofa_local.utils import make_divisible, val2list, MyNetwork
from ofa_local.utils.layers import set_layer_from_config
import gin

__all__ = ['OFAMobileNetV3']


@gin.configurable
class OFAMobileNetV3(MobileNetV3):

    def __init__(self, n_classes=1000, bn_param=(0.1, 1e-5), dropout_rate=0.1, base_stage_width=None, width_mult=1.0,
                 ks_list=3, expand_ratio_list=6, depth_list=4, dropblock=False, block_size=0):

        self.width_mult = width_mult
        self.ks_list = val2list(ks_list, 1)
        self.expand_ratio_list = val2list(expand_ratio_list, 1)
        self.depth_list = val2list(depth_list, 1)

        self.ks_list.sort()
        self.expand_ratio_list.sort()
        self.depth_list.sort()

        base_stage_width = [16, 16, 24, 40, 80, 112, 160, 960, 1280]

        final_expand_width = make_divisible(base_stage_width[-2] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        last_channel = make_divisible(base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)

        stride_stages = [1, 2, 2, 2, 1, 2]
        act_stages = ['relu', 'relu', 'relu', 'h_swish', 'h_swish', 'h_swish']
        se_stages = [False, False, True, False, True, True]
        n_block_list = [1] + [max(self.depth_list)] * 5
        width_list = []
        for base_width in base_stage_width[:-2]:
            width = make_divisible(base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            width_list.append(width)

        input_channel, first_block_dim = width_list[0], width_list[1]
        # first conv layer
        first_conv = ConvLayer(3, input_channel, kernel_size=3, stride=2, act_func='h_swish')
        first_block_conv = MBConvLayer(
            in_channels=input_channel, out_channels=first_block_dim, kernel_size=3, stride=stride_stages[0],
            expand_ratio=1, act_func=act_stages[0], use_se=se_stages[0],
        )
        first_block = ResidualBlock(
            first_block_conv,
            IdentityLayer(first_block_dim, first_block_dim) if input_channel == first_block_dim else None,
            dropout_rate, dropblock, block_size
        )

        # inverted residual blocks
        self.block_group_info = []
        blocks = [first_block]
        _block_index = 1
        feature_dim = first_block_dim

        for width, n_block, s, act_func, use_se in zip(width_list[2:], n_block_list[1:],
                                                       stride_stages[1:], act_stages[1:], se_stages[1:]):
            self.block_group_info.append([_block_index + i for i in range(n_block)])
            _block_index += n_block

            output_channel = width
            for i in range(n_block):
                if i == 0:
                    stride = s
                else:
                    stride = 1
                mobile_inverted_conv = DynamicMBConvLayer(
                    in_channel_list=val2list(feature_dim), out_channel_list=val2list(output_channel),
                    kernel_size_list=ks_list, expand_ratio_list=expand_ratio_list,
                    stride=stride, act_func=act_func, use_se=use_se,
                )
                if stride == 1 and feature_dim == output_channel:
                    shortcut = IdentityLayer(feature_dim, feature_dim)
                else:
                    shortcut = None
                blocks.append(ResidualBlock(mobile_inverted_conv, shortcut,
                                            dropout_rate, dropblock, block_size))
                feature_dim = output_channel
        # final expand layer, feature mix layer & classifier
        final_expand_layer = ConvLayer(feature_dim, final_expand_width, kernel_size=1, act_func='h_swish')
        feature_mix_layer = ConvLayer(
            final_expand_width, last_channel, kernel_size=1, bias=False, use_bn=False, act_func='h_swish',
        )

        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        super(OFAMobileNetV3, self).__init__(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)

        # set bn param
        self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])

        # runtime_depth
        self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]

    """ MyNetwork required methods """

    @staticmethod
    def name():
        return 'OFAMobileNetV3'

    def forward(self, x):
        # first conv
        x = self.first_conv(x)
        # first block
        x = self.blocks[0](x)
        # blocks
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            for idx in active_idx:
                x = self.blocks[idx](x)
        x = self.final_expand_layer(x)
        x = x.mean(3, keepdim=True).mean(2, keepdim=True)  # global average pooling
        x = self.feature_mix_layer(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = self.first_conv.module_str + '\n'
        _str += self.blocks[0].module_str + '\n'

        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            for idx in active_idx:
                _str += self.blocks[idx].module_str + '\n'

        _str += self.final_expand_layer.module_str + '\n'
        _str += self.feature_mix_layer.module_str + '\n'
        _str += self.classifier.module_str + '\n'
        return _str

    @property
    def config(self):
        return {
            'name': OFAMobileNetV3.__name__,
            'bn': self.get_bn_param(),
            'first_conv': self.first_conv.config,
            'blocks': [
                block.config for block in self.blocks
            ],
            'final_expand_layer': self.final_expand_layer.config,
            'feature_mix_layer': self.feature_mix_layer.config,
            'classifier': self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        raise ValueError('do not support this function')

    @property
    def grouped_block_index(self):
        return self.block_group_info

    def load_state_dict(self, state_dict, **kwargs):
        model_dict = self.state_dict()
        for key in state_dict:
            if '.mobile_inverted_conv.' in key:
                new_key = key.replace('.mobile_inverted_conv.', '.conv.')
            else:
                new_key = key
            if new_key in model_dict:
                pass
            elif '.bn.bn.' in new_key:
                new_key = new_key.replace('.bn.bn.', '.bn.')
            elif '.conv.conv.weight' in new_key:
                new_key = new_key.replace('.conv.conv.weight', '.conv.weight')
            elif '.linear.linear.' in new_key:
                new_key = new_key.replace('.linear.linear.', '.linear.')
            ##############################################################################
            elif '.linear.' in new_key:
                new_key = new_key.replace('.linear.', '.linear.linear.')
            elif 'bn.' in new_key:
                new_key = new_key.replace('bn.', 'bn.bn.')
            elif 'conv.weight' in new_key:
                new_key = new_key.replace('conv.weight', 'conv.conv.weight')
            else:
                raise ValueError(new_key)
            assert new_key in model_dict, '%s' % new_key
            model_dict[new_key] = state_dict[key]
        super(OFAMobileNetV3, self).load_state_dict(model_dict)

    """ set, sample and get active sub-networks """

    def set_max_net(self):
        self.set_active_subnet(ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list))

    def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
        ks = val2list(ks, len(self.blocks) - 1)
        expand_ratio = val2list(e, len(self.blocks) - 1)
        depth = val2list(d, len(self.block_group_info))

        for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
            if k is not None:
                block.conv.active_kernel_size = k
            if e is not None:
                block.conv.active_expand_ratio = e

        for i, d in enumerate(depth):
            if d is not None:
                self.runtime_depth[i] = min(len(self.block_group_info[i]), d)

    def set_constraint(self, include_list, constraint_type='depth'):
        if constraint_type == 'depth':
            self.__dict__['_depth_include_list'] = include_list.copy()
        elif constraint_type == 'expand_ratio':
            self.__dict__['_expand_include_list'] = include_list.copy()
        elif constraint_type == 'kernel_size':
            self.__dict__['_ks_include_list'] = include_list.copy()
        else:
            raise NotImplementedError

    def clear_constraint(self):
        self.__dict__['_depth_include_list'] = None
        self.__dict__['_expand_include_list'] = None
        self.__dict__['_ks_include_list'] = None

    def sample_active_subnet(self):
        ks_candidates = self.ks_list if self.__dict__.get('_ks_include_list', None) is None \
            else self.__dict__['_ks_include_list']
        expand_candidates = self.expand_ratio_list if self.__dict__.get('_expand_include_list', None) is None \
            else self.__dict__['_expand_include_list']
        depth_candidates = self.depth_list if self.__dict__.get('_depth_include_list', None) is None else \
            self.__dict__['_depth_include_list']

        # sample kernel size
        ks_setting = []
        if not isinstance(ks_candidates[0], list):
            ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
        for k_set in ks_candidates:
            k = random.choice(k_set)
            ks_setting.append(k)

        # sample expand ratio
        expand_setting = []
        if not isinstance(expand_candidates[0], list):
            expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
        for e_set in expand_candidates:
            e = random.choice(e_set)
            expand_setting.append(e)

        # sample depth
        depth_setting = []
        if not isinstance(depth_candidates[0], list):
            depth_candidates = [depth_candidates for _ in range(len(self.block_group_info))]
        for d_set in depth_candidates:
            d = random.choice(d_set)
            depth_setting.append(d)

        self.set_active_subnet(ks_setting, expand_setting, depth_setting)

        return {
            'ks': ks_setting,
            'e': expand_setting,
            'd': depth_setting,
        }
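
    # Hedged example (illustrative values, not from the original source): with
    # ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6] and depth_list=[2, 3, 4],
    # a sampled config looks like
    #   {'ks': [7, 3, 5, ...], 'e': [6, 3, 4, ...], 'd': [2, 4, 3, 2, 3]}
    # i.e. one (ks, e) pair per block and one depth per stage; each stage then
    # runs only its first d blocks at forward time.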

    def get_active_subnet(self, preserve_weight=True):
        first_conv = copy.deepcopy(self.first_conv)
        blocks = [copy.deepcopy(self.blocks[0])]

        final_expand_layer = copy.deepcopy(self.final_expand_layer)
        feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
        classifier = copy.deepcopy(self.classifier)

        input_channel = blocks[0].conv.out_channels
        # blocks
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            stage_blocks = []
            for idx in active_idx:
                stage_blocks.append(ResidualBlock(
                    self.blocks[idx].conv.get_active_subnet(input_channel, preserve_weight),
                    copy.deepcopy(self.blocks[idx].shortcut),
                    copy.deepcopy(self.blocks[idx].dropout_rate),
                    copy.deepcopy(self.blocks[idx].dropblock),
                    copy.deepcopy(self.blocks[idx].block_size),
                ))
                input_channel = stage_blocks[-1].conv.out_channels
            blocks += stage_blocks

        _subnet = MobileNetV3(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
        _subnet.set_bn_param(**self.get_bn_param())
        return _subnet

    def get_active_net_config(self):
        # first conv
        first_conv_config = self.first_conv.config
        first_block_config = self.blocks[0].config
        final_expand_config = self.final_expand_layer.config
        feature_mix_layer_config = self.feature_mix_layer.config
        classifier_config = self.classifier.config

        block_config_list = [first_block_config]
        input_channel = first_block_config['conv']['out_channels']
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            stage_blocks = []
            for idx in active_idx:
                stage_blocks.append({
                    'name': ResidualBlock.__name__,
                    'conv': self.blocks[idx].conv.get_active_subnet_config(input_channel),
                    'shortcut': self.blocks[idx].shortcut.config if self.blocks[idx].shortcut is not None else None,
                })
                input_channel = self.blocks[idx].conv.active_out_channel
            block_config_list += stage_blocks

        return {
            'name': MobileNetV3.__name__,
            'bn': self.get_bn_param(),
            'first_conv': first_conv_config,
            'blocks': block_config_list,
            'final_expand_layer': final_expand_config,
            'feature_mix_layer': feature_mix_layer_config,
            'classifier': classifier_config,
        }

    """ Width Related Methods """

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        for block in self.blocks[1:]:
            block.conv.re_organize_middle_weights(expand_ratio_stage)
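
# Hedged end-to-end sketch (illustrative only; assumes `import torch`):
#
#   net = OFAMobileNetV3(ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4])
#   cfg = net.sample_active_subnet()           # pick a random sub-network
#   subnet = net.get_active_subnet()           # standalone MobileNetV3 with copied weights
#   out = subnet(torch.randn(1, 3, 224, 224))  # logits of shape [1, n_classes]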
@@ -0,0 +1,331 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import random

from ofa_local.utils import make_divisible, val2list, MyNetwork
from ofa_local.imagenet_classification.elastic_nn.modules import DynamicMBConvLayer
from ofa_local.utils.layers import ConvLayer, IdentityLayer, LinearLayer, MBConvLayer, ResidualBlock
from ofa_local.imagenet_classification.networks.proxyless_nets import ProxylessNASNets

__all__ = ['OFAProxylessNASNets']


class OFAProxylessNASNets(ProxylessNASNets):

    def __init__(self, n_classes=1000, bn_param=(0.1, 1e-3), dropout_rate=0.1, base_stage_width=None, width_mult=1.0,
                 ks_list=3, expand_ratio_list=6, depth_list=4):

        self.width_mult = width_mult
        self.ks_list = val2list(ks_list, 1)
        self.expand_ratio_list = val2list(expand_ratio_list, 1)
        self.depth_list = val2list(depth_list, 1)

        self.ks_list.sort()
        self.expand_ratio_list.sort()
        self.depth_list.sort()

        if base_stage_width == 'google':
            # MobileNetV2 Stage Width
            base_stage_width = [32, 16, 24, 32, 64, 96, 160, 320, 1280]
        else:
            # ProxylessNAS Stage Width
            base_stage_width = [32, 16, 24, 40, 80, 96, 192, 320, 1280]

        input_channel = make_divisible(base_stage_width[0] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        first_block_width = make_divisible(base_stage_width[1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        last_channel = make_divisible(base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)

        # first conv layer
        first_conv = ConvLayer(
            3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='relu6', ops_order='weight_bn_act'
        )
        # first block
        first_block_conv = MBConvLayer(
            in_channels=input_channel, out_channels=first_block_width, kernel_size=3, stride=1,
            expand_ratio=1, act_func='relu6',
        )
        first_block = ResidualBlock(first_block_conv, None)

        input_channel = first_block_width
        # inverted residual blocks
        self.block_group_info = []
        blocks = [first_block]
        _block_index = 1

        stride_stages = [2, 2, 2, 1, 2, 1]
        n_block_list = [max(self.depth_list)] * 5 + [1]

        width_list = []
        for base_width in base_stage_width[2:-1]:
            width = make_divisible(base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            width_list.append(width)

        for width, n_block, s in zip(width_list, n_block_list, stride_stages):
            self.block_group_info.append([_block_index + i for i in range(n_block)])
            _block_index += n_block

            output_channel = width
            for i in range(n_block):
                if i == 0:
                    stride = s
                else:
                    stride = 1

                mobile_inverted_conv = DynamicMBConvLayer(
                    in_channel_list=val2list(input_channel, 1), out_channel_list=val2list(output_channel, 1),
                    kernel_size_list=ks_list, expand_ratio_list=expand_ratio_list, stride=stride, act_func='relu6',
                )

                if stride == 1 and input_channel == output_channel:
                    shortcut = IdentityLayer(input_channel, input_channel)
                else:
                    shortcut = None

                mb_inverted_block = ResidualBlock(mobile_inverted_conv, shortcut)

                blocks.append(mb_inverted_block)
                input_channel = output_channel
        # 1x1_conv before global average pooling
        feature_mix_layer = ConvLayer(
            input_channel, last_channel, kernel_size=1, use_bn=True, act_func='relu6',
        )
        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        super(OFAProxylessNASNets, self).__init__(first_conv, blocks, feature_mix_layer, classifier)

        # set bn param
        self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])

        # runtime_depth
        self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]

    """ MyNetwork required methods """

    @staticmethod
    def name():
        return 'OFAProxylessNASNets'

    def forward(self, x):
        # first conv
        x = self.first_conv(x)
        # first block
        x = self.blocks[0](x)

        # blocks
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            for idx in active_idx:
                x = self.blocks[idx](x)

        # feature_mix_layer
        x = self.feature_mix_layer(x)
        x = x.mean(3).mean(2)

        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = self.first_conv.module_str + '\n'
        _str += self.blocks[0].module_str + '\n'

        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            for idx in active_idx:
                _str += self.blocks[idx].module_str + '\n'
        _str += self.feature_mix_layer.module_str + '\n'
        _str += self.classifier.module_str + '\n'
        return _str

    @property
    def config(self):
        return {
            'name': OFAProxylessNASNets.__name__,
            'bn': self.get_bn_param(),
            'first_conv': self.first_conv.config,
            'blocks': [
                block.config for block in self.blocks
            ],
            'feature_mix_layer': None if self.feature_mix_layer is None else self.feature_mix_layer.config,
            'classifier': self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        raise ValueError('do not support this function')

    @property
    def grouped_block_index(self):
        return self.block_group_info

    def load_state_dict(self, state_dict, **kwargs):
        model_dict = self.state_dict()
        for key in state_dict:
            if '.mobile_inverted_conv.' in key:
                new_key = key.replace('.mobile_inverted_conv.', '.conv.')
            else:
                new_key = key
            if new_key in model_dict:
                pass
            elif '.bn.bn.' in new_key:
                new_key = new_key.replace('.bn.bn.', '.bn.')
            elif '.conv.conv.weight' in new_key:
                new_key = new_key.replace('.conv.conv.weight', '.conv.weight')
            elif '.linear.linear.' in new_key:
                new_key = new_key.replace('.linear.linear.', '.linear.')
            ##############################################################################
            elif '.linear.' in new_key:
                new_key = new_key.replace('.linear.', '.linear.linear.')
            elif 'bn.' in new_key:
                new_key = new_key.replace('bn.', 'bn.bn.')
            elif 'conv.weight' in new_key:
                new_key = new_key.replace('conv.weight', 'conv.conv.weight')
            else:
                raise ValueError(new_key)
            assert new_key in model_dict, '%s' % new_key
            model_dict[new_key] = state_dict[key]
        super(OFAProxylessNASNets, self).load_state_dict(model_dict)

    """ set, sample and get active sub-networks """

    def set_max_net(self):
        self.set_active_subnet(ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list))

    def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
        ks = val2list(ks, len(self.blocks) - 1)
        expand_ratio = val2list(e, len(self.blocks) - 1)
        depth = val2list(d, len(self.block_group_info))

        for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
            if k is not None:
                block.conv.active_kernel_size = k
            if e is not None:
                block.conv.active_expand_ratio = e

        for i, d in enumerate(depth):
            if d is not None:
                self.runtime_depth[i] = min(len(self.block_group_info[i]), d)

    def set_constraint(self, include_list, constraint_type='depth'):
        if constraint_type == 'depth':
            self.__dict__['_depth_include_list'] = include_list.copy()
        elif constraint_type == 'expand_ratio':
            self.__dict__['_expand_include_list'] = include_list.copy()
        elif constraint_type == 'kernel_size':
            self.__dict__['_ks_include_list'] = include_list.copy()
        else:
            raise NotImplementedError

    def clear_constraint(self):
        self.__dict__['_depth_include_list'] = None
        self.__dict__['_expand_include_list'] = None
        self.__dict__['_ks_include_list'] = None

    def sample_active_subnet(self):
        ks_candidates = self.ks_list if self.__dict__.get('_ks_include_list', None) is None \
            else self.__dict__['_ks_include_list']
        expand_candidates = self.expand_ratio_list if self.__dict__.get('_expand_include_list', None) is None \
            else self.__dict__['_expand_include_list']
        depth_candidates = self.depth_list if self.__dict__.get('_depth_include_list', None) is None else \
            self.__dict__['_depth_include_list']

        # sample kernel size
        ks_setting = []
        if not isinstance(ks_candidates[0], list):
            ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
        for k_set in ks_candidates:
            k = random.choice(k_set)
            ks_setting.append(k)

        # sample expand ratio
        expand_setting = []
        if not isinstance(expand_candidates[0], list):
            expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
        for e_set in expand_candidates:
            e = random.choice(e_set)
            expand_setting.append(e)

        # sample depth
        depth_setting = []
        if not isinstance(depth_candidates[0], list):
            depth_candidates = [depth_candidates for _ in range(len(self.block_group_info))]
        for d_set in depth_candidates:
            d = random.choice(d_set)
            depth_setting.append(d)

        depth_setting[-1] = 1  # the last stage is built with exactly one block (see n_block_list)
        self.set_active_subnet(ks_setting, expand_setting, depth_setting)

        return {
            'ks': ks_setting,
            'e': expand_setting,
            'd': depth_setting,
        }

    def get_active_subnet(self, preserve_weight=True):
        first_conv = copy.deepcopy(self.first_conv)
        blocks = [copy.deepcopy(self.blocks[0])]
        feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
        classifier = copy.deepcopy(self.classifier)

        input_channel = blocks[0].conv.out_channels
        # blocks
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            stage_blocks = []
            for idx in active_idx:
                stage_blocks.append(ResidualBlock(
                    self.blocks[idx].conv.get_active_subnet(input_channel, preserve_weight),
                    copy.deepcopy(self.blocks[idx].shortcut)
                ))
                input_channel = stage_blocks[-1].conv.out_channels
            blocks += stage_blocks

        _subnet = ProxylessNASNets(first_conv, blocks, feature_mix_layer, classifier)
        _subnet.set_bn_param(**self.get_bn_param())
        return _subnet

    def get_active_net_config(self):
        first_conv_config = self.first_conv.config
        first_block_config = self.blocks[0].config
        feature_mix_layer_config = self.feature_mix_layer.config
        classifier_config = self.classifier.config

        block_config_list = [first_block_config]
        input_channel = first_block_config['conv']['out_channels']
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            stage_blocks = []
            for idx in active_idx:
                stage_blocks.append({
                    'name': ResidualBlock.__name__,
                    'conv': self.blocks[idx].conv.get_active_subnet_config(input_channel),
                    'shortcut': self.blocks[idx].shortcut.config if self.blocks[idx].shortcut is not None else None,
                })
                try:
                    input_channel = self.blocks[idx].conv.active_out_channel
                except Exception:
                    input_channel = self.blocks[idx].conv.out_channels
            block_config_list += stage_blocks

        return {
            'name': ProxylessNASNets.__name__,
            'bn': self.get_bn_param(),
            'first_conv': first_conv_config,
            'blocks': block_config_list,
            'feature_mix_layer': feature_mix_layer_config,
            'classifier': classifier_config,
        }

    """ Width Related Methods """

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        for block in self.blocks[1:]:
            block.conv.re_organize_middle_weights(expand_ratio_stage)
@@ -0,0 +1,267 @@
import random

from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_layers import DynamicConvLayer, DynamicLinearLayer
from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_layers import DynamicResNetBottleneckBlock
from ofa_local.utils.layers import IdentityLayer, ResidualBlock
from ofa_local.imagenet_classification.networks import ResNets
from ofa_local.utils import make_divisible, val2list, MyNetwork

__all__ = ['OFAResNets']


class OFAResNets(ResNets):

    def __init__(self, n_classes=1000, bn_param=(0.1, 1e-5), dropout_rate=0,
                 depth_list=2, expand_ratio_list=0.25, width_mult_list=1.0):

        self.depth_list = val2list(depth_list)
        self.expand_ratio_list = val2list(expand_ratio_list)
        self.width_mult_list = val2list(width_mult_list)
        # sort
        self.depth_list.sort()
        self.expand_ratio_list.sort()
        self.width_mult_list.sort()

        input_channel = [
            make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE) for width_mult in self.width_mult_list
        ]
        mid_input_channel = [
            make_divisible(channel // 2, MyNetwork.CHANNEL_DIVISIBLE) for channel in input_channel
        ]

        stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
        for i, width in enumerate(stage_width_list):
            stage_width_list[i] = [
                make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE) for width_mult in self.width_mult_list
            ]

        n_block_list = [base_depth + max(self.depth_list) for base_depth in ResNets.BASE_DEPTH_LIST]
        stride_list = [1, 2, 2, 2]

        # build input stem
        input_stem = [
            DynamicConvLayer(val2list(3), mid_input_channel, 3, stride=2, use_bn=True, act_func='relu'),
            ResidualBlock(
                DynamicConvLayer(mid_input_channel, mid_input_channel, 3, stride=1, use_bn=True, act_func='relu'),
                IdentityLayer(mid_input_channel, mid_input_channel)
            ),
            DynamicConvLayer(mid_input_channel, input_channel, 3, stride=1, use_bn=True, act_func='relu')
        ]

        # blocks
        blocks = []
        for d, width, s in zip(n_block_list, stage_width_list, stride_list):
            for i in range(d):
                stride = s if i == 0 else 1
                bottleneck_block = DynamicResNetBottleneckBlock(
                    input_channel, width, expand_ratio_list=self.expand_ratio_list,
                    kernel_size=3, stride=stride, act_func='relu', downsample_mode='avgpool_conv',
                )
                blocks.append(bottleneck_block)
                input_channel = width
        # classifier
        classifier = DynamicLinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)

        super(OFAResNets, self).__init__(input_stem, blocks, classifier)

        # set bn param
        self.set_bn_param(*bn_param)

        # runtime_depth
        self.input_stem_skipping = 0
        self.runtime_depth = [0] * len(n_block_list)

    @property
    def ks_list(self):
        return [3]

    @staticmethod
    def name():
        return 'OFAResNets'

    def forward(self, x):
        for layer in self.input_stem:
            if self.input_stem_skipping > 0 \
                    and isinstance(layer, ResidualBlock) and isinstance(layer.shortcut, IdentityLayer):
                pass
            else:
                x = layer(x)
        x = self.max_pooling(x)
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            depth_param = self.runtime_depth[stage_id]
            active_idx = block_idx[:len(block_idx) - depth_param]
            for idx in active_idx:
                x = self.blocks[idx](x)
        x = self.global_avg_pool(x)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = ''
        for layer in self.input_stem:
            if self.input_stem_skipping > 0 \
                    and isinstance(layer, ResidualBlock) and isinstance(layer.shortcut, IdentityLayer):
                pass
            else:
                _str += layer.module_str + '\n'
        _str += 'max_pooling(ks=3, stride=2)\n'
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            depth_param = self.runtime_depth[stage_id]
            active_idx = block_idx[:len(block_idx) - depth_param]
            for idx in active_idx:
                _str += self.blocks[idx].module_str + '\n'
        _str += self.global_avg_pool.__repr__() + '\n'
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        return {
            'name': OFAResNets.__name__,
            'bn': self.get_bn_param(),
            'input_stem': [
                layer.config for layer in self.input_stem
            ],
            'blocks': [
                block.config for block in self.blocks
            ],
            'classifier': self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        raise ValueError('do not support this function')

    def load_state_dict(self, state_dict, **kwargs):
        model_dict = self.state_dict()
        for key in state_dict:
            new_key = key
            if new_key in model_dict:
                pass
            elif '.linear.' in new_key:
                new_key = new_key.replace('.linear.', '.linear.linear.')
            elif 'bn.' in new_key:
                new_key = new_key.replace('bn.', 'bn.bn.')
            elif 'conv.weight' in new_key:
                new_key = new_key.replace('conv.weight', 'conv.conv.weight')
            else:
                raise ValueError(new_key)
            assert new_key in model_dict, '%s' % new_key
            model_dict[new_key] = state_dict[key]
        super(OFAResNets, self).load_state_dict(model_dict)

    """ set, sample and get active sub-networks """

    def set_max_net(self):
        self.set_active_subnet(d=max(self.depth_list), e=max(self.expand_ratio_list), w=len(self.width_mult_list) - 1)

    def set_active_subnet(self, d=None, e=None, w=None, **kwargs):
        depth = val2list(d, len(ResNets.BASE_DEPTH_LIST) + 1)
        expand_ratio = val2list(e, len(self.blocks))
        width_mult = val2list(w, len(ResNets.BASE_DEPTH_LIST) + 2)

        for block, e in zip(self.blocks, expand_ratio):
            if e is not None:
                block.active_expand_ratio = e

        if width_mult[0] is not None:
            self.input_stem[1].conv.active_out_channel = self.input_stem[0].active_out_channel = \
                self.input_stem[0].out_channel_list[width_mult[0]]
        if width_mult[1] is not None:
            self.input_stem[2].active_out_channel = self.input_stem[2].out_channel_list[width_mult[1]]

        if depth[0] is not None:
            self.input_stem_skipping = (depth[0] != max(self.depth_list))
        for stage_id, (block_idx, d, w) in enumerate(zip(self.grouped_block_index, depth[1:], width_mult[2:])):
            if d is not None:
                # runtime_depth stores how many blocks to *skip* per stage, not how many to keep
                self.runtime_depth[stage_id] = max(self.depth_list) - d
            if w is not None:
                for idx in block_idx:
                    self.blocks[idx].active_out_channel = self.blocks[idx].out_channel_list[w]

    def sample_active_subnet(self):
        # sample expand ratio
        expand_setting = []
        for block in self.blocks:
            expand_setting.append(random.choice(block.expand_ratio_list))

        # sample depth
        depth_setting = [random.choice([max(self.depth_list), min(self.depth_list)])]
        for stage_id in range(len(ResNets.BASE_DEPTH_LIST)):
            depth_setting.append(random.choice(self.depth_list))

        # sample width_mult
        width_mult_setting = [
            random.choice(list(range(len(self.input_stem[0].out_channel_list)))),
            random.choice(list(range(len(self.input_stem[2].out_channel_list)))),
        ]
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            stage_first_block = self.blocks[block_idx[0]]
            width_mult_setting.append(
                random.choice(list(range(len(stage_first_block.out_channel_list))))
            )

        arch_config = {
            'd': depth_setting,
            'e': expand_setting,
            'w': width_mult_setting
        }
        self.set_active_subnet(**arch_config)
        return arch_config
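
    # Hedged note on the encoding (values below are illustrative): unlike the
    # MobileNetV3 variant, 'd' holds one entry for the input stem plus one per
    # stage, 'w' holds indices into out_channel_list rather than multipliers,
    # and 'e' has one expand ratio per bottleneck block, e.g.
    #   {'d': [2, 2, 1, 2, 0], 'e': [0.25, 0.35, ...], 'w': [1, 0, 1, 1, 0, 1]}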

    def get_active_subnet(self, preserve_weight=True):
        input_stem = [self.input_stem[0].get_active_subnet(3, preserve_weight)]
        if self.input_stem_skipping <= 0:
            input_stem.append(ResidualBlock(
                self.input_stem[1].conv.get_active_subnet(self.input_stem[0].active_out_channel, preserve_weight),
                IdentityLayer(self.input_stem[0].active_out_channel, self.input_stem[0].active_out_channel)
            ))
        input_stem.append(self.input_stem[2].get_active_subnet(self.input_stem[0].active_out_channel, preserve_weight))
        input_channel = self.input_stem[2].active_out_channel

        blocks = []
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            depth_param = self.runtime_depth[stage_id]
            active_idx = block_idx[:len(block_idx) - depth_param]
            for idx in active_idx:
                blocks.append(self.blocks[idx].get_active_subnet(input_channel, preserve_weight))
                input_channel = self.blocks[idx].active_out_channel
        classifier = self.classifier.get_active_subnet(input_channel, preserve_weight)
        subnet = ResNets(input_stem, blocks, classifier)

        subnet.set_bn_param(**self.get_bn_param())
        return subnet

    def get_active_net_config(self):
        input_stem_config = [self.input_stem[0].get_active_subnet_config(3)]
        if self.input_stem_skipping <= 0:
            input_stem_config.append({
                'name': ResidualBlock.__name__,
                'conv': self.input_stem[1].conv.get_active_subnet_config(self.input_stem[0].active_out_channel),
                'shortcut': IdentityLayer(self.input_stem[0].active_out_channel, self.input_stem[0].active_out_channel),
            })
        input_stem_config.append(self.input_stem[2].get_active_subnet_config(self.input_stem[0].active_out_channel))
        input_channel = self.input_stem[2].active_out_channel

        blocks_config = []
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            depth_param = self.runtime_depth[stage_id]
            active_idx = block_idx[:len(block_idx) - depth_param]
            for idx in active_idx:
                blocks_config.append(self.blocks[idx].get_active_subnet_config(input_channel))
                input_channel = self.blocks[idx].active_out_channel
        classifier_config = self.classifier.get_active_subnet_config(input_channel)
        return {
            'name': ResNets.__name__,
            'bn': self.get_bn_param(),
            'input_stem': input_stem_config,
            'blocks': blocks_config,
            'classifier': classifier_config,
        }

    """ Width Related Methods """

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        for block in self.blocks:
            block.re_organize_middle_weights(expand_ratio_stage)
@@ -0,0 +1,5 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .progressive_shrinking import *
@@ -0,0 +1,320 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import torch.nn as nn
import random
import time
import torch
import torch.nn.functional as F
from tqdm import tqdm

from ofa.utils import AverageMeter, cross_entropy_loss_with_soft_target
from ofa.utils import DistributedMetric, list_mean, subset_mean, val2list, MyRandomResizedCrop
from ofa.imagenet_classification.run_manager import DistributedRunManager

__all__ = [
    'validate', 'train_one_epoch', 'train', 'load_models',
    'train_elastic_depth', 'train_elastic_expand', 'train_elastic_width_mult',
]


def validate(run_manager, epoch=0, is_test=False, image_size_list=None,
             ks_list=None, expand_ratio_list=None, depth_list=None, width_mult_list=None, additional_setting=None):
    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        dynamic_net = dynamic_net.module

    dynamic_net.eval()

    if image_size_list is None:
        image_size_list = val2list(run_manager.run_config.data_provider.image_size, 1)
    if ks_list is None:
        ks_list = dynamic_net.ks_list
    if expand_ratio_list is None:
        expand_ratio_list = dynamic_net.expand_ratio_list
    if depth_list is None:
        depth_list = dynamic_net.depth_list
    if width_mult_list is None:
        if 'width_mult_list' in dynamic_net.__dict__:
            width_mult_list = list(range(len(dynamic_net.width_mult_list)))
        else:
            width_mult_list = [0]

    subnet_settings = []
    for d in depth_list:
        for e in expand_ratio_list:
            for k in ks_list:
                for w in width_mult_list:
                    for img_size in image_size_list:
                        subnet_settings.append([{
                            'image_size': img_size,
                            'd': d,
                            'e': e,
                            'ks': k,
                            'w': w,
                        }, 'R%s-D%s-E%s-K%s-W%s' % (img_size, d, e, k, w)])
    if additional_setting is not None:
        subnet_settings += additional_setting

    losses_of_subnets, top1_of_subnets, top5_of_subnets = [], [], []

    valid_log = ''
    for setting, name in subnet_settings:
        run_manager.write_log('-' * 30 + ' Validate %s ' % name + '-' * 30, 'train', should_print=False)
        run_manager.run_config.data_provider.assign_active_img_size(setting.pop('image_size'))
        dynamic_net.set_active_subnet(**setting)
        run_manager.write_log(dynamic_net.module_str, 'train', should_print=False)

        run_manager.reset_running_statistics(dynamic_net)
        loss, (top1, top5) = run_manager.validate(epoch=epoch, is_test=is_test, run_str=name, net=dynamic_net)
        losses_of_subnets.append(loss)
        top1_of_subnets.append(top1)
        top5_of_subnets.append(top5)
        valid_log += '%s (%.3f), ' % (name, top1)

    return list_mean(losses_of_subnets), list_mean(top1_of_subnets), list_mean(top5_of_subnets), valid_log
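
# Hedged example (illustrative values): with depth_list=[2, 4],
# expand_ratio_list=[3, 6], ks_list=[3, 7], width_mult_list=[0] and
# image_size_list=[224], the grid above yields 2*2*2*1*1 = 8 uniform settings
# such as
#   [{'image_size': 224, 'd': 2, 'e': 3, 'ks': 3, 'w': 0}, 'R224-D2-E3-K3-W0']
# and each one is validated after re-calibrating BN statistics for that subnet.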


def train_one_epoch(run_manager, args, epoch, warmup_epochs=0, warmup_lr=0):
    dynamic_net = run_manager.network
    distributed = isinstance(run_manager, DistributedRunManager)

    # switch to train mode
    dynamic_net.train()
    if distributed:
        run_manager.run_config.train_loader.sampler.set_epoch(epoch)
    MyRandomResizedCrop.EPOCH = epoch

    nBatch = len(run_manager.run_config.train_loader)

    data_time = AverageMeter()
    losses = DistributedMetric('train_loss') if distributed else AverageMeter()
    metric_dict = run_manager.get_metric_dict()

    with tqdm(total=nBatch,
              desc='Train Epoch #{}'.format(epoch + 1),
              disable=distributed and not run_manager.is_root) as t:
        end = time.time()
        for i, (images, labels) in enumerate(run_manager.run_config.train_loader):
            MyRandomResizedCrop.BATCH = i
            data_time.update(time.time() - end)
            if epoch < warmup_epochs:
                new_lr = run_manager.run_config.warmup_adjust_learning_rate(
                    run_manager.optimizer, warmup_epochs * nBatch, nBatch, epoch, i, warmup_lr,
                )
            else:
                new_lr = run_manager.run_config.adjust_learning_rate(
                    run_manager.optimizer, epoch - warmup_epochs, i, nBatch
                )

            images, labels = images.cuda(), labels.cuda()
            target = labels

            # soft target
            if args.kd_ratio > 0:
                args.teacher_model.train()
                with torch.no_grad():
                    soft_logits = args.teacher_model(images).detach()
                    soft_label = F.softmax(soft_logits, dim=1)

            # clean gradients
            dynamic_net.zero_grad()

            loss_of_subnets = []
            # compute output
            subnet_str = ''
            for _ in range(args.dynamic_batch_size):
                # set random seed before sampling
                subnet_seed = int('%d%.3d%.3d' % (epoch * nBatch + i, _, 0))
                random.seed(subnet_seed)
                subnet_settings = dynamic_net.sample_active_subnet()
                subnet_str += '%d: ' % _ + ','.join(['%s_%s' % (
                    key, '%.1f' % subset_mean(val, 0) if isinstance(val, list) else val
                ) for key, val in subnet_settings.items()]) + ' || '

                output = run_manager.net(images)
                if args.kd_ratio == 0:
                    loss = run_manager.train_criterion(output, labels)
                    loss_type = 'ce'
                else:
                    if args.kd_type == 'ce':
                        kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
                    else:
                        kd_loss = F.mse_loss(output, soft_logits)
                    loss = args.kd_ratio * kd_loss + run_manager.train_criterion(output, labels)
                    loss_type = '%.1fkd-%s & ce' % (args.kd_ratio, args.kd_type)

                # measure accuracy and record loss
                loss_of_subnets.append(loss)
                run_manager.update_metric(metric_dict, output, target)

                loss.backward()
            run_manager.optimizer.step()

            losses.update(list_mean(loss_of_subnets), images.size(0))

            t.set_postfix({
                'loss': losses.avg.item(),
                **run_manager.get_metric_vals(metric_dict, return_dict=True),
                'R': images.size(2),
                'lr': new_lr,
                'loss_type': loss_type,
                'seed': str(subnet_seed),
                'str': subnet_str,
                'data_time': data_time.avg,
            })
            t.update(1)
            end = time.time()
    return losses.avg.item(), run_manager.get_metric_vals(metric_dict)
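
# Hedged sketch of the soft-target loss used above (standard knowledge
# distillation; the exact helper lives in ofa.utils and may differ in detail):
#
#   cross_entropy_loss_with_soft_target(pred, soft_target)
#     ~= mean_over_batch( -sum_c soft_target[:, c] * log_softmax(pred)[:, c] )
#
# so with kd_ratio=1.0 and kd_type='ce', the total per-subnet loss is
#   loss = 1.0 * CE(output, softmax(teacher_logits)) + CE(output, labels)
# and gradients from all dynamic_batch_size sampled subnets are accumulated
# before the single optimizer step.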


def train(run_manager, args, validate_func=None):
    distributed = isinstance(run_manager, DistributedRunManager)
    if validate_func is None:
        validate_func = validate

    for epoch in range(run_manager.start_epoch, run_manager.run_config.n_epochs + args.warmup_epochs):
        train_loss, (train_top1, train_top5) = train_one_epoch(
            run_manager, args, epoch, args.warmup_epochs, args.warmup_lr)

        if (epoch + 1) % args.validation_frequency == 0:
            val_loss, val_acc, val_acc5, _val_log = validate_func(run_manager, epoch=epoch, is_test=False)
            # best_acc
            is_best = val_acc > run_manager.best_acc
            run_manager.best_acc = max(run_manager.best_acc, val_acc)
            if not distributed or run_manager.is_root:
                val_log = 'Valid [{0}/{1}] loss={2:.3f}, top-1={3:.3f} ({4:.3f})'. \
                    format(epoch + 1 - args.warmup_epochs, run_manager.run_config.n_epochs, val_loss, val_acc,
                           run_manager.best_acc)
                val_log += ', Train top-1 {top1:.3f}, Train loss {loss:.3f}\t'.format(top1=train_top1, loss=train_loss)
                val_log += _val_log
                run_manager.write_log(val_log, 'valid', should_print=False)

                run_manager.save_model({
                    'epoch': epoch,
                    'best_acc': run_manager.best_acc,
                    'optimizer': run_manager.optimizer.state_dict(),
                    'state_dict': run_manager.network.state_dict(),
                }, is_best=is_best)


def load_models(run_manager, dynamic_net, model_path=None):
    # specify init path
    init = torch.load(model_path, map_location='cpu')['state_dict']
    dynamic_net.load_state_dict(init)
    run_manager.write_log('Loaded init from %s' % model_path, 'valid')


def train_elastic_depth(train_func, run_manager, args, validate_func_dict):
    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        dynamic_net = dynamic_net.module

    depth_stage_list = dynamic_net.depth_list.copy()
    depth_stage_list.sort(reverse=True)
    n_stages = len(depth_stage_list) - 1
    current_stage = n_stages - 1

    # load pretrained models
    if run_manager.start_epoch == 0 and not args.resume:
        validate_func_dict['depth_list'] = sorted(dynamic_net.depth_list)

        load_models(run_manager, dynamic_net, model_path=args.ofa_checkpoint_path)
        # validate after loading weights
        run_manager.write_log('%.3f\t%.3f\t%.3f\t%s' %
                              validate(run_manager, is_test=True, **validate_func_dict), 'valid')
    else:
        assert args.resume

    run_manager.write_log(
        '-' * 30 + 'Supporting Elastic Depth: %s -> %s' %
        (depth_stage_list[:current_stage + 1], depth_stage_list[:current_stage + 2]) + '-' * 30, 'valid'
    )
    # add depth list constraints
    if len(set(dynamic_net.ks_list)) == 1 and len(set(dynamic_net.expand_ratio_list)) == 1:
        validate_func_dict['depth_list'] = depth_stage_list
    else:
        validate_func_dict['depth_list'] = sorted({min(depth_stage_list), max(depth_stage_list)})

    # train
    train_func(
        run_manager, args,
        lambda _run_manager, epoch, is_test: validate(_run_manager, epoch, is_test, **validate_func_dict)
    )
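
# Hedged illustration of the progressive-shrinking schedule (values assumed):
# with depth_list=[2, 3, 4], depth_stage_list becomes [4, 3, 2] and
# current_stage = 1, so the log reads
#   Supporting Elastic Depth: [4, 3] -> [4, 3, 2]
# i.e. the checkpoint already supports depths {4, 3} and this phase adds depth 2.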


def train_elastic_expand(train_func, run_manager, args, validate_func_dict):
    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        dynamic_net = dynamic_net.module

    expand_stage_list = dynamic_net.expand_ratio_list.copy()
    expand_stage_list.sort(reverse=True)
    n_stages = len(expand_stage_list) - 1
    current_stage = n_stages - 1

    # load pretrained models
    if run_manager.start_epoch == 0 and not args.resume:
        validate_func_dict['expand_ratio_list'] = sorted(dynamic_net.expand_ratio_list)

        load_models(run_manager, dynamic_net, model_path=args.ofa_checkpoint_path)
        dynamic_net.re_organize_middle_weights(expand_ratio_stage=current_stage)
        run_manager.write_log('%.3f\t%.3f\t%.3f\t%s' %
                              validate(run_manager, is_test=True, **validate_func_dict), 'valid')
    else:
        assert args.resume

    run_manager.write_log(
        '-' * 30 + 'Supporting Elastic Expand Ratio: %s -> %s' %
        (expand_stage_list[:current_stage + 1], expand_stage_list[:current_stage + 2]) + '-' * 30, 'valid'
    )
    if len(set(dynamic_net.ks_list)) == 1 and len(set(dynamic_net.depth_list)) == 1:
        validate_func_dict['expand_ratio_list'] = expand_stage_list
    else:
        validate_func_dict['expand_ratio_list'] = sorted({min(expand_stage_list), max(expand_stage_list)})

    # train
    train_func(
        run_manager, args,
        lambda _run_manager, epoch, is_test: validate(_run_manager, epoch, is_test, **validate_func_dict)
    )


def train_elastic_width_mult(train_func, run_manager, args, validate_func_dict):
    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        dynamic_net = dynamic_net.module

    width_stage_list = dynamic_net.width_mult_list.copy()
    width_stage_list.sort(reverse=True)
    n_stages = len(width_stage_list) - 1
    current_stage = n_stages - 1

    if run_manager.start_epoch == 0 and not args.resume:
        load_models(run_manager, dynamic_net, model_path=args.ofa_checkpoint_path)
        if current_stage == 0:
            dynamic_net.re_organize_middle_weights(expand_ratio_stage=len(dynamic_net.expand_ratio_list) - 1)
            run_manager.write_log('reorganize_middle_weights (expand_ratio_stage=%d)'
                                  % (len(dynamic_net.expand_ratio_list) - 1), 'valid')
            try:
                dynamic_net.re_organize_outer_weights()
                run_manager.write_log('reorganize_outer_weights', 'valid')
            except Exception:
                pass
        run_manager.write_log('%.3f\t%.3f\t%.3f\t%s' %
                              validate(run_manager, is_test=True, **validate_func_dict), 'valid')
    else:
        assert args.resume

    run_manager.write_log(
        '-' * 30 + 'Supporting Elastic Width Mult: %s -> %s' %
        (width_stage_list[:current_stage + 1], width_stage_list[:current_stage + 2]) + '-' * 30, 'valid'
    )
    validate_func_dict['width_mult_list'] = sorted({0, len(width_stage_list) - 1})

    # train
    train_func(
        run_manager, args,
        lambda _run_manager, epoch, is_test: validate(_run_manager, epoch, is_test, **validate_func_dict)
    )
|
||||
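
For reference, these phase helpers are dispatched per task in OFA's training entry point; a minimal usage sketch (illustrative, not part of this commit; assumes `run_manager`, `args`, and the unwrapped `dynamic_net` are already built, as in OFA's train_ofa_net.py, and that `args.task` names the current phase):

# usage sketch (illustrative)
validate_func_dict = {
    'image_size_list': {224},
    'ks_list': sorted({min(dynamic_net.ks_list), max(dynamic_net.ks_list)}),
    'expand_ratio_list': sorted({min(dynamic_net.expand_ratio_list), max(dynamic_net.expand_ratio_list)}),
    'depth_list': sorted({min(dynamic_net.depth_list), max(dynamic_net.depth_list)}),
}
if args.task == 'depth':
    train_elastic_depth(train, run_manager, args, validate_func_dict)
elif args.task == 'expand':
    train_elastic_expand(train, run_manager, args, validate_func_dict)
elif args.task == 'width_mult':
    train_elastic_width_mult(train, run_manager, args, validate_func_dict)
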
@@ -0,0 +1,70 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import torch.nn.functional as F
import torch.nn as nn
import torch

from ofa_local.utils import AverageMeter, get_net_device, DistributedTensor
from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_op import DynamicBatchNorm2d

__all__ = ['set_running_statistics']


def set_running_statistics(model, data_loader, distributed=False):
    bn_mean = {}
    bn_var = {}

    forward_model = copy.deepcopy(model)
    for name, m in forward_model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            if distributed:
                bn_mean[name] = DistributedTensor(name + '#mean')
                bn_var[name] = DistributedTensor(name + '#var')
            else:
                bn_mean[name] = AverageMeter()
                bn_var[name] = AverageMeter()

            def new_forward(bn, mean_est, var_est):
                def lambda_forward(x):
                    batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)  # 1, C, 1, 1
                    batch_var = (x - batch_mean) * (x - batch_mean)
                    batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)

                    batch_mean = torch.squeeze(batch_mean)
                    batch_var = torch.squeeze(batch_var)

                    mean_est.update(batch_mean.data, x.size(0))
                    var_est.update(batch_var.data, x.size(0))

                    # bn forward using calculated mean & var
                    _feature_dim = batch_mean.size(0)
                    return F.batch_norm(
                        x, batch_mean, batch_var, bn.weight[:_feature_dim],
                        bn.bias[:_feature_dim], False,
                        0.0, bn.eps,
                    )

                return lambda_forward

            m.forward = new_forward(m, bn_mean[name], bn_var[name])

    if len(bn_mean) == 0:
        # skip if there is no batch normalization layers in the network
        return

    with torch.no_grad():
        DynamicBatchNorm2d.SET_RUNNING_STATISTICS = True
        for images, labels in data_loader:
            images = images.to(get_net_device(forward_model))
            forward_model(images)
        DynamicBatchNorm2d.SET_RUNNING_STATISTICS = False

    for name, m in model.named_modules():
        if name in bn_mean and bn_mean[name].count > 0:
            feature_dim = bn_mean[name].avg.size(0)
            assert isinstance(m, nn.BatchNorm2d)
            m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
            m.running_var.data[:feature_dim].copy_(bn_var[name].avg)
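
Because every sampled sub-network shares BN layers with the full network, the stored running statistics are stale for any specific sub-network; this helper re-estimates them from a small calibration set. A minimal sketch (illustrative; assumes a dynamic OFA network exposing a `set_active_subnet` method and a `run_config` as defined later in this commit):

# usage sketch (illustrative)
net.set_active_subnet(ks=7, e=6, d=4)                        # hypothetical active sub-network
calib_loader = run_config.random_sub_train_loader(2000, 200)
set_running_statistics(net, calib_loader)                    # forward passes re-estimate per-layer mean/var
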
@@ -0,0 +1,18 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .proxyless_nets import *
from .mobilenet_v3 import *
from .resnets import *


def get_net_by_name(name):
    if name == ProxylessNASNets.__name__:
        return ProxylessNASNets
    elif name == MobileNetV3.__name__:
        return MobileNetV3
    elif name == ResNets.__name__:
        return ResNets
    else:
        raise ValueError('unrecognized type of network: %s' % name)
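
A small sketch of how this factory pairs with each network's `config` / `build_from_config` round-trip (illustrative; assumes a `net.config` JSON such as the one dumped by the run managers later in this commit):

# usage sketch (illustrative)
import json

net_config = json.load(open('net.config'))      # e.g. {'name': 'MobileNetV3', ...}
net_cls = get_net_by_name(net_config['name'])   # -> MobileNetV3 class
net = net_cls.build_from_config(net_config)
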
@@ -0,0 +1,218 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import torch.nn as nn

from ofa_local.utils.layers import set_layer_from_config, MBConvLayer, ConvLayer, IdentityLayer, LinearLayer, ResidualBlock
from ofa_local.utils import MyNetwork, make_divisible, MyGlobalAvgPool2d

__all__ = ['MobileNetV3', 'MobileNetV3Large']


class MobileNetV3(MyNetwork):

    def __init__(self, first_conv, blocks, final_expand_layer, feature_mix_layer, classifier):
        super(MobileNetV3, self).__init__()

        self.first_conv = first_conv
        self.blocks = nn.ModuleList(blocks)
        self.final_expand_layer = final_expand_layer
        self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=True)
        self.feature_mix_layer = feature_mix_layer
        self.classifier = classifier

    def forward(self, x):
        x = self.first_conv(x)
        for block in self.blocks:
            x = block(x)
        x = self.final_expand_layer(x)
        x = self.global_avg_pool(x)  # global average pooling
        x = self.feature_mix_layer(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = self.first_conv.module_str + '\n'
        for block in self.blocks:
            _str += block.module_str + '\n'
        _str += self.final_expand_layer.module_str + '\n'
        _str += self.global_avg_pool.__repr__() + '\n'
        _str += self.feature_mix_layer.module_str + '\n'
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        return {
            'name': MobileNetV3.__name__,
            'bn': self.get_bn_param(),
            'first_conv': self.first_conv.config,
            'blocks': [
                block.config for block in self.blocks
            ],
            'final_expand_layer': self.final_expand_layer.config,
            'feature_mix_layer': self.feature_mix_layer.config,
            'classifier': self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        first_conv = set_layer_from_config(config['first_conv'])
        final_expand_layer = set_layer_from_config(config['final_expand_layer'])
        feature_mix_layer = set_layer_from_config(config['feature_mix_layer'])
        classifier = set_layer_from_config(config['classifier'])

        blocks = []
        for block_config in config['blocks']:
            blocks.append(ResidualBlock.build_from_config(block_config))

        net = MobileNetV3(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
        if 'bn' in config:
            net.set_bn_param(**config['bn'])
        else:
            net.set_bn_param(momentum=0.1, eps=1e-5)

        return net

    def zero_last_gamma(self):
        for m in self.modules():
            if isinstance(m, ResidualBlock):
                if isinstance(m.conv, MBConvLayer) and isinstance(m.shortcut, IdentityLayer):
                    m.conv.point_linear.bn.weight.data.zero_()

    @property
    def grouped_block_index(self):
        info_list = []
        block_index_list = []
        for i, block in enumerate(self.blocks[1:], 1):
            if block.shortcut is None and len(block_index_list) > 0:
                info_list.append(block_index_list)
                block_index_list = []
            block_index_list.append(i)
        if len(block_index_list) > 0:
            info_list.append(block_index_list)
        return info_list

    @staticmethod
    def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate):
        # first conv layer
        first_conv = ConvLayer(
            3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='h_swish', ops_order='weight_bn_act'
        )
        # build mobile blocks
        feature_dim = input_channel
        blocks = []
        for stage_id, block_config_list in cfg.items():
            for k, mid_channel, out_channel, use_se, act_func, stride, expand_ratio in block_config_list:
                mb_conv = MBConvLayer(
                    feature_dim, out_channel, k, stride, expand_ratio, mid_channel, act_func, use_se
                )
                if stride == 1 and out_channel == feature_dim:
                    shortcut = IdentityLayer(out_channel, out_channel)
                else:
                    shortcut = None
                blocks.append(ResidualBlock(mb_conv, shortcut))
                feature_dim = out_channel
        # final expand layer
        final_expand_layer = ConvLayer(
            feature_dim, feature_dim * 6, kernel_size=1, use_bn=True, act_func='h_swish', ops_order='weight_bn_act',
        )
        # feature mix layer
        feature_mix_layer = ConvLayer(
            feature_dim * 6, last_channel, kernel_size=1, bias=False, use_bn=False, act_func='h_swish',
        )
        # classifier
        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier

    @staticmethod
    def adjust_cfg(cfg, ks=None, expand_ratio=None, depth_param=None, stage_width_list=None):
        for i, (stage_id, block_config_list) in enumerate(cfg.items()):
            for block_config in block_config_list:
                if ks is not None and stage_id != '0':
                    block_config[0] = ks
                if expand_ratio is not None and stage_id != '0':
                    block_config[-1] = expand_ratio
                    block_config[1] = None
                    if stage_width_list is not None:
                        block_config[2] = stage_width_list[i]
            if depth_param is not None and stage_id != '0':
                new_block_config_list = [block_config_list[0]]
                new_block_config_list += [copy.deepcopy(block_config_list[-1]) for _ in range(depth_param - 1)]
                cfg[stage_id] = new_block_config_list
        return cfg

    def load_state_dict(self, state_dict, **kwargs):
        current_state_dict = self.state_dict()

        for key in state_dict:
            if key not in current_state_dict:
                assert '.mobile_inverted_conv.' in key
                new_key = key.replace('.mobile_inverted_conv.', '.conv.')
            else:
                new_key = key
            current_state_dict[new_key] = state_dict[key]
        super(MobileNetV3, self).load_state_dict(current_state_dict)


class MobileNetV3Large(MobileNetV3):

    def __init__(self, n_classes=1000, width_mult=1.0, bn_param=(0.1, 1e-5), dropout_rate=0.2,
                 ks=None, expand_ratio=None, depth_param=None, stage_width_list=None):
        input_channel = 16
        last_channel = 1280

        input_channel = make_divisible(input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        last_channel = make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE) \
            if width_mult > 1.0 else last_channel

        cfg = {
            # k, exp, c, se, nl, s, e,
            '0': [
                [3, 16, 16, False, 'relu', 1, 1],
            ],
            '1': [
                [3, 64, 24, False, 'relu', 2, None],  # 4
                [3, 72, 24, False, 'relu', 1, None],  # 3
            ],
            '2': [
                [5, 72, 40, True, 'relu', 2, None],  # 3
                [5, 120, 40, True, 'relu', 1, None],  # 3
                [5, 120, 40, True, 'relu', 1, None],  # 3
            ],
            '3': [
                [3, 240, 80, False, 'h_swish', 2, None],  # 6
                [3, 200, 80, False, 'h_swish', 1, None],  # 2.5
                [3, 184, 80, False, 'h_swish', 1, None],  # 2.3
                [3, 184, 80, False, 'h_swish', 1, None],  # 2.3
            ],
            '4': [
                [3, 480, 112, True, 'h_swish', 1, None],  # 6
                [3, 672, 112, True, 'h_swish', 1, None],  # 6
            ],
            '5': [
                [5, 672, 160, True, 'h_swish', 2, None],  # 6
                [5, 960, 160, True, 'h_swish', 1, None],  # 6
                [5, 960, 160, True, 'h_swish', 1, None],  # 6
            ]
        }

        cfg = self.adjust_cfg(cfg, ks, expand_ratio, depth_param, stage_width_list)
        # width multiplier on mobile setting, change `exp: 1` and `c: 2`
        for stage_id, block_config_list in cfg.items():
            for block_config in block_config_list:
                if block_config[1] is not None:
                    block_config[1] = make_divisible(block_config[1] * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
                block_config[2] = make_divisible(block_config[2] * width_mult, MyNetwork.CHANNEL_DIVISIBLE)

        first_conv, blocks, final_expand_layer, feature_mix_layer, classifier = self.build_net_via_cfg(
            cfg, input_channel, last_channel, n_classes, dropout_rate
        )
        super(MobileNetV3Large, self).__init__(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
        # set bn param
        self.set_bn_param(*bn_param)
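
A minimal sketch constructing the static backbone defined above with uniform kernel size 5, expand ratio 4, and 3 layers per stage, then running a dummy forward pass (illustrative, not part of the commit):

# usage sketch (illustrative)
import torch

net = MobileNetV3Large(n_classes=1000, width_mult=1.0, ks=5, expand_ratio=4, depth_param=3)
out = net(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
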
@@ -0,0 +1,210 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import json
import torch.nn as nn

from ofa_local.utils.layers import set_layer_from_config, MBConvLayer, ConvLayer, IdentityLayer, LinearLayer, ResidualBlock
from ofa_local.utils import download_url, make_divisible, val2list, MyNetwork, MyGlobalAvgPool2d

__all__ = ['proxyless_base', 'ProxylessNASNets', 'MobileNetV2']


def proxyless_base(net_config=None, n_classes=None, bn_param=None, dropout_rate=None,
                   local_path='~/.torch/proxylessnas/'):
    assert net_config is not None, 'Please input a network config'
    if 'http' in net_config:
        net_config_path = download_url(net_config, local_path)
    else:
        net_config_path = net_config
    net_config_json = json.load(open(net_config_path, 'r'))

    if n_classes is not None:
        net_config_json['classifier']['out_features'] = n_classes
    if dropout_rate is not None:
        net_config_json['classifier']['dropout_rate'] = dropout_rate

    net = ProxylessNASNets.build_from_config(net_config_json)
    if bn_param is not None:
        net.set_bn_param(*bn_param)

    return net


class ProxylessNASNets(MyNetwork):

    def __init__(self, first_conv, blocks, feature_mix_layer, classifier):
        super(ProxylessNASNets, self).__init__()

        self.first_conv = first_conv
        self.blocks = nn.ModuleList(blocks)
        self.feature_mix_layer = feature_mix_layer
        self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
        self.classifier = classifier

    def forward(self, x):
        x = self.first_conv(x)
        for block in self.blocks:
            x = block(x)
        if self.feature_mix_layer is not None:
            x = self.feature_mix_layer(x)
        x = self.global_avg_pool(x)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = self.first_conv.module_str + '\n'
        for block in self.blocks:
            _str += block.module_str + '\n'
        _str += self.feature_mix_layer.module_str + '\n'
        _str += self.global_avg_pool.__repr__() + '\n'
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        return {
            'name': ProxylessNASNets.__name__,
            'bn': self.get_bn_param(),
            'first_conv': self.first_conv.config,
            'blocks': [
                block.config for block in self.blocks
            ],
            'feature_mix_layer': None if self.feature_mix_layer is None else self.feature_mix_layer.config,
            'classifier': self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        first_conv = set_layer_from_config(config['first_conv'])
        feature_mix_layer = set_layer_from_config(config['feature_mix_layer'])
        classifier = set_layer_from_config(config['classifier'])

        blocks = []
        for block_config in config['blocks']:
            blocks.append(ResidualBlock.build_from_config(block_config))

        net = ProxylessNASNets(first_conv, blocks, feature_mix_layer, classifier)
        if 'bn' in config:
            net.set_bn_param(**config['bn'])
        else:
            net.set_bn_param(momentum=0.1, eps=1e-3)

        return net

    def zero_last_gamma(self):
        for m in self.modules():
            if isinstance(m, ResidualBlock):
                if isinstance(m.conv, MBConvLayer) and isinstance(m.shortcut, IdentityLayer):
                    m.conv.point_linear.bn.weight.data.zero_()

    @property
    def grouped_block_index(self):
        info_list = []
        block_index_list = []
        for i, block in enumerate(self.blocks[1:], 1):
            if block.shortcut is None and len(block_index_list) > 0:
                info_list.append(block_index_list)
                block_index_list = []
            block_index_list.append(i)
        if len(block_index_list) > 0:
            info_list.append(block_index_list)
        return info_list

    def load_state_dict(self, state_dict, **kwargs):
        current_state_dict = self.state_dict()

        for key in state_dict:
            if key not in current_state_dict:
                assert '.mobile_inverted_conv.' in key
                new_key = key.replace('.mobile_inverted_conv.', '.conv.')
            else:
                new_key = key
            current_state_dict[new_key] = state_dict[key]
        super(ProxylessNASNets, self).load_state_dict(current_state_dict)


class MobileNetV2(ProxylessNASNets):

    def __init__(self, n_classes=1000, width_mult=1.0, bn_param=(0.1, 1e-3), dropout_rate=0.2,
                 ks=None, expand_ratio=None, depth_param=None, stage_width_list=None):

        ks = 3 if ks is None else ks
        expand_ratio = 6 if expand_ratio is None else expand_ratio

        input_channel = 32
        last_channel = 1280

        input_channel = make_divisible(input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        last_channel = make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE) \
            if width_mult > 1.0 else last_channel

        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [expand_ratio, 24, 2, 2],
            [expand_ratio, 32, 3, 2],
            [expand_ratio, 64, 4, 2],
            [expand_ratio, 96, 3, 1],
            [expand_ratio, 160, 3, 2],
            [expand_ratio, 320, 1, 1],
        ]

        if depth_param is not None:
            assert isinstance(depth_param, int)
            for i in range(1, len(inverted_residual_setting) - 1):
                inverted_residual_setting[i][2] = depth_param

        if stage_width_list is not None:
            for i in range(len(inverted_residual_setting)):
                inverted_residual_setting[i][1] = stage_width_list[i]

        ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1)
        _pt = 0

        # first conv layer
        first_conv = ConvLayer(
            3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='relu6', ops_order='weight_bn_act'
        )
        # inverted residual blocks
        blocks = []
        for t, c, n, s in inverted_residual_setting:
            output_channel = make_divisible(c * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            for i in range(n):
                if i == 0:
                    stride = s
                else:
                    stride = 1
                if t == 1:
                    kernel_size = 3
                else:
                    kernel_size = ks[_pt]
                    _pt += 1
                mobile_inverted_conv = MBConvLayer(
                    in_channels=input_channel, out_channels=output_channel, kernel_size=kernel_size, stride=stride,
                    expand_ratio=t,
                )
                if stride == 1:
                    if input_channel == output_channel:
                        shortcut = IdentityLayer(input_channel, input_channel)
                    else:
                        shortcut = None
                else:
                    shortcut = None
                blocks.append(
                    ResidualBlock(mobile_inverted_conv, shortcut)
                )
                input_channel = output_channel
        # 1x1_conv before global average pooling
        feature_mix_layer = ConvLayer(
            input_channel, last_channel, kernel_size=1, use_bn=True, act_func='relu6', ops_order='weight_bn_act',
        )

        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        super(MobileNetV2, self).__init__(first_conv, blocks, feature_mix_layer, classifier)

        # set bn param
        self.set_bn_param(*bn_param)
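
The `config` property and `build_from_config` are designed as inverses, which is how checkpointed architectures are restored by the run managers later in this commit. A minimal sketch (illustrative):

# usage sketch (illustrative)
net = MobileNetV2(n_classes=1000, width_mult=1.0)
clone = ProxylessNASNets.build_from_config(net.config)  # same architecture, freshly initialized weights
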
@@ -0,0 +1,192 @@
import torch.nn as nn

from ofa_local.utils.layers import set_layer_from_config, ConvLayer, IdentityLayer, LinearLayer
from ofa_local.utils.layers import ResNetBottleneckBlock, ResidualBlock
from ofa_local.utils import make_divisible, MyNetwork, MyGlobalAvgPool2d

__all__ = ['ResNets', 'ResNet50', 'ResNet50D']


class ResNets(MyNetwork):

    BASE_DEPTH_LIST = [2, 2, 4, 2]
    STAGE_WIDTH_LIST = [256, 512, 1024, 2048]

    def __init__(self, input_stem, blocks, classifier):
        super(ResNets, self).__init__()

        self.input_stem = nn.ModuleList(input_stem)
        self.max_pooling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.blocks = nn.ModuleList(blocks)
        self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
        self.classifier = classifier

    def forward(self, x):
        for layer in self.input_stem:
            x = layer(x)
        x = self.max_pooling(x)
        for block in self.blocks:
            x = block(x)
        x = self.global_avg_pool(x)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = ''
        for layer in self.input_stem:
            _str += layer.module_str + '\n'
        _str += 'max_pooling(ks=3, stride=2)\n'
        for block in self.blocks:
            _str += block.module_str + '\n'
        _str += self.global_avg_pool.__repr__() + '\n'
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        return {
            'name': ResNets.__name__,
            'bn': self.get_bn_param(),
            'input_stem': [
                layer.config for layer in self.input_stem
            ],
            'blocks': [
                block.config for block in self.blocks
            ],
            'classifier': self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        classifier = set_layer_from_config(config['classifier'])

        input_stem = []
        for layer_config in config['input_stem']:
            input_stem.append(set_layer_from_config(layer_config))
        blocks = []
        for block_config in config['blocks']:
            blocks.append(set_layer_from_config(block_config))

        net = ResNets(input_stem, blocks, classifier)
        if 'bn' in config:
            net.set_bn_param(**config['bn'])
        else:
            net.set_bn_param(momentum=0.1, eps=1e-5)

        return net

    def zero_last_gamma(self):
        for m in self.modules():
            if isinstance(m, ResNetBottleneckBlock) and isinstance(m.downsample, IdentityLayer):
                m.conv3.bn.weight.data.zero_()

    @property
    def grouped_block_index(self):
        info_list = []
        block_index_list = []
        for i, block in enumerate(self.blocks):
            if not isinstance(block.downsample, IdentityLayer) and len(block_index_list) > 0:
                info_list.append(block_index_list)
                block_index_list = []
            block_index_list.append(i)
        if len(block_index_list) > 0:
            info_list.append(block_index_list)
        return info_list

    def load_state_dict(self, state_dict, **kwargs):
        super(ResNets, self).load_state_dict(state_dict)


class ResNet50(ResNets):

    def __init__(self, n_classes=1000, width_mult=1.0, bn_param=(0.1, 1e-5), dropout_rate=0,
                 expand_ratio=None, depth_param=None):

        expand_ratio = 0.25 if expand_ratio is None else expand_ratio

        input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
        for i, width in enumerate(stage_width_list):
            stage_width_list[i] = make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE)

        depth_list = [3, 4, 6, 3]
        if depth_param is not None:
            for i, depth in enumerate(ResNets.BASE_DEPTH_LIST):
                depth_list[i] = depth + depth_param

        stride_list = [1, 2, 2, 2]

        # build input stem
        input_stem = [ConvLayer(
            3, input_channel, kernel_size=7, stride=2, use_bn=True, act_func='relu', ops_order='weight_bn_act',
        )]

        # blocks
        blocks = []
        for d, width, s in zip(depth_list, stage_width_list, stride_list):
            for i in range(d):
                stride = s if i == 0 else 1
                bottleneck_block = ResNetBottleneckBlock(
                    input_channel, width, kernel_size=3, stride=stride, expand_ratio=expand_ratio,
                    act_func='relu', downsample_mode='conv',
                )
                blocks.append(bottleneck_block)
                input_channel = width
        # classifier
        classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)

        super(ResNet50, self).__init__(input_stem, blocks, classifier)

        # set bn param
        self.set_bn_param(*bn_param)


class ResNet50D(ResNets):

    def __init__(self, n_classes=1000, width_mult=1.0, bn_param=(0.1, 1e-5), dropout_rate=0,
                 expand_ratio=None, depth_param=None):

        expand_ratio = 0.25 if expand_ratio is None else expand_ratio

        input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        mid_input_channel = make_divisible(input_channel // 2, MyNetwork.CHANNEL_DIVISIBLE)
        stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
        for i, width in enumerate(stage_width_list):
            stage_width_list[i] = make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE)

        depth_list = [3, 4, 6, 3]
        if depth_param is not None:
            for i, depth in enumerate(ResNets.BASE_DEPTH_LIST):
                depth_list[i] = depth + depth_param

        stride_list = [1, 2, 2, 2]

        # build input stem
        input_stem = [
            ConvLayer(3, mid_input_channel, 3, stride=2, use_bn=True, act_func='relu'),
            ResidualBlock(
                ConvLayer(mid_input_channel, mid_input_channel, 3, stride=1, use_bn=True, act_func='relu'),
                IdentityLayer(mid_input_channel, mid_input_channel)
            ),
            ConvLayer(mid_input_channel, input_channel, 3, stride=1, use_bn=True, act_func='relu')
        ]

        # blocks
        blocks = []
        for d, width, s in zip(depth_list, stage_width_list, stride_list):
            for i in range(d):
                stride = s if i == 0 else 1
                bottleneck_block = ResNetBottleneckBlock(
                    input_channel, width, kernel_size=3, stride=stride, expand_ratio=expand_ratio,
                    act_func='relu', downsample_mode='avgpool_conv',
                )
                blocks.append(bottleneck_block)
                input_channel = width
        # classifier
        classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)

        super(ResNet50D, self).__init__(input_stem, blocks, classifier)

        # set bn param
        self.set_bn_param(*bn_param)
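
Note that `depth_param` is added on top of BASE_DEPTH_LIST rather than replacing the default per-stage depths. A minimal sketch (illustrative):

# usage sketch (illustrative)
net = ResNet50D(n_classes=1000, width_mult=1.0, depth_param=2)
# per-stage depths become [2+2, 2+2, 4+2, 2+2] = [4, 4, 6, 4]
print(len(net.blocks))  # 18
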
@@ -0,0 +1,7 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .run_config import *
from .run_manager import *
from .distributed_run_manager import *
@@ -0,0 +1,381 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import json
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from tqdm import tqdm

from ofa_local.utils import cross_entropy_with_label_smoothing, cross_entropy_loss_with_soft_target, write_log, init_models
from ofa_local.utils import DistributedMetric, list_mean, get_net_info, accuracy, AverageMeter, mix_labels, mix_images
from ofa_local.utils import MyRandomResizedCrop

__all__ = ['DistributedRunManager']


class DistributedRunManager:

    def __init__(self, path, net, run_config, hvd_compression, backward_steps=1, is_root=False, init=True):
        import horovod.torch as hvd

        self.path = path
        self.net = net
        self.run_config = run_config
        self.is_root = is_root

        self.best_acc = 0.0
        self.start_epoch = 0

        os.makedirs(self.path, exist_ok=True)

        self.net.cuda()
        cudnn.benchmark = True
        if init and self.is_root:
            init_models(self.net, self.run_config.model_init)
        if self.is_root:
            # print net info
            net_info = get_net_info(self.net, self.run_config.data_provider.data_shape)
            with open('%s/net_info.txt' % self.path, 'w') as fout:
                fout.write(json.dumps(net_info, indent=4) + '\n')
                try:
                    fout.write(self.net.module_str + '\n')
                except Exception:
                    fout.write('%s do not support `module_str`' % type(self.net))
                fout.write('%s\n' % self.run_config.data_provider.train.dataset.transform)
                fout.write('%s\n' % self.run_config.data_provider.test.dataset.transform)
                fout.write('%s\n' % self.net)

        # criterion
        if isinstance(self.run_config.mixup_alpha, float):
            self.train_criterion = cross_entropy_loss_with_soft_target
        elif self.run_config.label_smoothing > 0:
            self.train_criterion = lambda pred, target: \
                cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
        else:
            self.train_criterion = nn.CrossEntropyLoss()
        self.test_criterion = nn.CrossEntropyLoss()

        # optimizer
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split('#')
            net_params = [
                self.net.get_parameters(keys, mode='exclude'),  # parameters with weight decay
                self.net.get_parameters(keys, mode='include'),  # parameters without weight decay
            ]
        else:
            # noinspection PyBroadException
            try:
                net_params = self.network.weight_parameters()
            except Exception:
                net_params = []
                for param in self.network.parameters():
                    if param.requires_grad:
                        net_params.append(param)
        self.optimizer = self.run_config.build_optimizer(net_params)
        self.optimizer = hvd.DistributedOptimizer(
            self.optimizer, named_parameters=self.net.named_parameters(), compression=hvd_compression,
            backward_passes_per_step=backward_steps,
        )

    """ save path and log path """

    @property
    def save_path(self):
        if self.__dict__.get('_save_path', None) is None:
            save_path = os.path.join(self.path, 'checkpoint')
            os.makedirs(save_path, exist_ok=True)
            self.__dict__['_save_path'] = save_path
        return self.__dict__['_save_path']

    @property
    def logs_path(self):
        if self.__dict__.get('_logs_path', None) is None:
            logs_path = os.path.join(self.path, 'logs')
            os.makedirs(logs_path, exist_ok=True)
            self.__dict__['_logs_path'] = logs_path
        return self.__dict__['_logs_path']

    @property
    def network(self):
        return self.net

    @network.setter
    def network(self, new_val):
        self.net = new_val

    def write_log(self, log_str, prefix='valid', should_print=True, mode='a'):
        if self.is_root:
            write_log(self.logs_path, log_str, prefix, should_print, mode)

    """ save & load model & save_config & broadcast """

    def save_config(self, extra_run_config=None, extra_net_config=None):
        if self.is_root:
            run_save_path = os.path.join(self.path, 'run.config')
            if not os.path.isfile(run_save_path):
                run_config = self.run_config.config
                if extra_run_config is not None:
                    run_config.update(extra_run_config)
                json.dump(run_config, open(run_save_path, 'w'), indent=4)
                print('Run configs dump to %s' % run_save_path)

            try:
                net_save_path = os.path.join(self.path, 'net.config')
                net_config = self.net.config
                if extra_net_config is not None:
                    net_config.update(extra_net_config)
                json.dump(net_config, open(net_save_path, 'w'), indent=4)
                print('Network configs dump to %s' % net_save_path)
            except Exception:
                print('%s do not support net config' % type(self.net))

    def save_model(self, checkpoint=None, is_best=False, model_name=None):
        if self.is_root:
            if checkpoint is None:
                checkpoint = {'state_dict': self.net.state_dict()}

            if model_name is None:
                model_name = 'checkpoint.pth.tar'

            latest_fname = os.path.join(self.save_path, 'latest.txt')
            model_path = os.path.join(self.save_path, model_name)
            with open(latest_fname, 'w') as _fout:
                _fout.write(model_path + '\n')
            torch.save(checkpoint, model_path)

            if is_best:
                best_path = os.path.join(self.save_path, 'model_best.pth.tar')
                torch.save({'state_dict': checkpoint['state_dict']}, best_path)

    def load_model(self, model_fname=None):
        if self.is_root:
            latest_fname = os.path.join(self.save_path, 'latest.txt')
            if model_fname is None and os.path.exists(latest_fname):
                with open(latest_fname, 'r') as fin:
                    model_fname = fin.readline()
                    if model_fname[-1] == '\n':
                        model_fname = model_fname[:-1]
            # noinspection PyBroadException
            try:
                if model_fname is None or not os.path.exists(model_fname):
                    model_fname = '%s/checkpoint.pth.tar' % self.save_path
                    with open(latest_fname, 'w') as fout:
                        fout.write(model_fname + '\n')
                print("=> loading checkpoint '{}'".format(model_fname))
                checkpoint = torch.load(model_fname, map_location='cpu')
            except Exception:
                self.write_log('fail to load checkpoint from %s' % self.save_path, 'valid')
                return

            self.net.load_state_dict(checkpoint['state_dict'])
            if 'epoch' in checkpoint:
                self.start_epoch = checkpoint['epoch'] + 1
            if 'best_acc' in checkpoint:
                self.best_acc = checkpoint['best_acc']
            if 'optimizer' in checkpoint:
                self.optimizer.load_state_dict(checkpoint['optimizer'])

            self.write_log("=> loaded checkpoint '{}'".format(model_fname), 'valid')

    # noinspection PyArgumentList
    def broadcast(self):
        import horovod.torch as hvd
        self.start_epoch = hvd.broadcast(torch.LongTensor(1).fill_(self.start_epoch)[0], 0, name='start_epoch').item()
        self.best_acc = hvd.broadcast(torch.Tensor(1).fill_(self.best_acc)[0], 0, name='best_acc').item()
        hvd.broadcast_parameters(self.net.state_dict(), 0)
        hvd.broadcast_optimizer_state(self.optimizer, 0)

    """ metric related """

    def get_metric_dict(self):
        return {
            'top1': DistributedMetric('top1'),
            'top5': DistributedMetric('top5'),
        }

    def update_metric(self, metric_dict, output, labels):
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        metric_dict['top1'].update(acc1[0], output.size(0))
        metric_dict['top5'].update(acc5[0], output.size(0))

    def get_metric_vals(self, metric_dict, return_dict=False):
        if return_dict:
            return {
                key: metric_dict[key].avg.item() for key in metric_dict
            }
        else:
            return [metric_dict[key].avg.item() for key in metric_dict]

    def get_metric_names(self):
        return 'top1', 'top5'

    """ train & validate """

    def validate(self, epoch=0, is_test=False, run_str='', net=None, data_loader=None, no_logs=False):
        if net is None:
            net = self.net
        if data_loader is None:
            if is_test:
                data_loader = self.run_config.test_loader
            else:
                data_loader = self.run_config.valid_loader

        net.eval()

        losses = DistributedMetric('val_loss')
        metric_dict = self.get_metric_dict()

        with torch.no_grad():
            with tqdm(total=len(data_loader),
                      desc='Validate Epoch #{} {}'.format(epoch + 1, run_str),
                      disable=no_logs or not self.is_root) as t:
                for i, (images, labels) in enumerate(data_loader):
                    images, labels = images.cuda(), labels.cuda()
                    # compute output
                    output = net(images)
                    loss = self.test_criterion(output, labels)
                    # measure accuracy and record loss
                    losses.update(loss, images.size(0))
                    self.update_metric(metric_dict, output, labels)
                    t.set_postfix({
                        'loss': losses.avg.item(),
                        **self.get_metric_vals(metric_dict, return_dict=True),
                        'img_size': images.size(2),
                    })
                    t.update(1)
        return losses.avg.item(), self.get_metric_vals(metric_dict)

    def validate_all_resolution(self, epoch=0, is_test=False, net=None):
        if net is None:
            net = self.net
        if isinstance(self.run_config.data_provider.image_size, list):
            img_size_list, loss_list, top1_list, top5_list = [], [], [], []
            for img_size in self.run_config.data_provider.image_size:
                img_size_list.append(img_size)
                self.run_config.data_provider.assign_active_img_size(img_size)
                self.reset_running_statistics(net=net)
                loss, (top1, top5) = self.validate(epoch, is_test, net=net)
                loss_list.append(loss)
                top1_list.append(top1)
                top5_list.append(top5)
            return img_size_list, loss_list, top1_list, top5_list
        else:
            loss, (top1, top5) = self.validate(epoch, is_test, net=net)
            return [self.run_config.data_provider.active_img_size], [loss], [top1], [top5]

    def train_one_epoch(self, args, epoch, warmup_epochs=5, warmup_lr=0):
        self.net.train()
        self.run_config.train_loader.sampler.set_epoch(epoch)  # required by distributed sampler
        MyRandomResizedCrop.EPOCH = epoch  # required by elastic resolution

        nBatch = len(self.run_config.train_loader)

        losses = DistributedMetric('train_loss')
        metric_dict = self.get_metric_dict()
        data_time = AverageMeter()

        with tqdm(total=nBatch,
                  desc='Train Epoch #{}'.format(epoch + 1),
                  disable=not self.is_root) as t:
            end = time.time()
            for i, (images, labels) in enumerate(self.run_config.train_loader):
                MyRandomResizedCrop.BATCH = i
                data_time.update(time.time() - end)
                if epoch < warmup_epochs:
                    new_lr = self.run_config.warmup_adjust_learning_rate(
                        self.optimizer, warmup_epochs * nBatch, nBatch, epoch, i, warmup_lr,
                    )
                else:
                    new_lr = self.run_config.adjust_learning_rate(self.optimizer, epoch - warmup_epochs, i, nBatch)

                images, labels = images.cuda(), labels.cuda()
                target = labels
                if isinstance(self.run_config.mixup_alpha, float):
                    # transform data
                    random.seed(int('%d%.3d' % (i, epoch)))
                    lam = random.betavariate(self.run_config.mixup_alpha, self.run_config.mixup_alpha)
                    images = mix_images(images, lam)
                    labels = mix_labels(
                        labels, lam, self.run_config.data_provider.n_classes, self.run_config.label_smoothing
                    )

                # soft target
                if args.teacher_model is not None:
                    args.teacher_model.train()
                    with torch.no_grad():
                        soft_logits = args.teacher_model(images).detach()
                        soft_label = F.softmax(soft_logits, dim=1)

                # compute output
                output = self.net(images)

                if args.teacher_model is None:
                    loss = self.train_criterion(output, labels)
                    loss_type = 'ce'
                else:
                    if args.kd_type == 'ce':
                        kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
                    else:
                        kd_loss = F.mse_loss(output, soft_logits)
                    loss = args.kd_ratio * kd_loss + self.train_criterion(output, labels)
                    loss_type = '%.1fkd+ce' % args.kd_ratio

                # update
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure accuracy and record loss
                losses.update(loss, images.size(0))
                self.update_metric(metric_dict, output, target)

                t.set_postfix({
                    'loss': losses.avg.item(),
                    **self.get_metric_vals(metric_dict, return_dict=True),
                    'img_size': images.size(2),
                    'lr': new_lr,
                    'loss_type': loss_type,
                    'data_time': data_time.avg,
                })
                t.update(1)
                end = time.time()

        return losses.avg.item(), self.get_metric_vals(metric_dict)

    def train(self, args, warmup_epochs=5, warmup_lr=0):
        for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epochs):
            train_loss, (train_top1, train_top5) = self.train_one_epoch(args, epoch, warmup_epochs, warmup_lr)
            img_size, val_loss, val_top1, val_top5 = self.validate_all_resolution(epoch, is_test=False)

            is_best = list_mean(val_top1) > self.best_acc
            self.best_acc = max(self.best_acc, list_mean(val_top1))
            if self.is_root:
                val_log = '[{0}/{1}]\tloss {2:.3f}\t{6} acc {3:.3f} ({4:.3f})\t{7} acc {5:.3f}\t' \
                          'Train {6} {top1:.3f}\tloss {train_loss:.3f}\t'. \
                    format(epoch + 1 - warmup_epochs, self.run_config.n_epochs, list_mean(val_loss),
                           list_mean(val_top1), self.best_acc, list_mean(val_top5), *self.get_metric_names(),
                           top1=train_top1, train_loss=train_loss)
                for i_s, v_a in zip(img_size, val_top1):
                    val_log += '(%d, %.3f), ' % (i_s, v_a)
                self.write_log(val_log, prefix='valid', should_print=False)

                self.save_model({
                    'epoch': epoch,
                    'best_acc': self.best_acc,
                    'optimizer': self.optimizer.state_dict(),
                    'state_dict': self.net.state_dict(),
                }, is_best=is_best)

    def reset_running_statistics(self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None):
        # import from the vendored ofa_local package, consistent with the rest of this repo
        from ofa_local.imagenet_classification.elastic_nn.utils import set_running_statistics
        if net is None:
            net = self.net
        if data_loader is None:
            data_loader = self.run_config.random_sub_train_loader(subset_size, subset_batch_size)
        set_running_statistics(net, data_loader)
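
A minimal sketch of the Horovod setup this class expects (illustrative; assumes Horovod is installed and `net` / `run_config` are built as elsewhere in this repo; the path and `backward_steps` value are hypothetical, mirroring OFA's distributed entry point):

# usage sketch (illustrative)
import horovod.torch as hvd

hvd.init()
torch.cuda.set_device(hvd.local_rank())
run_manager = DistributedRunManager(
    './exp/demo', net, run_config, hvd_compression=hvd.Compression.none,
    backward_steps=1, is_root=(hvd.rank() == 0),
)
run_manager.broadcast()  # sync start_epoch / best_acc / weights / optimizer state from rank 0
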
@@ -0,0 +1,161 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from ofa_local.utils import calc_learning_rate, build_optimizer
from ofa_local.imagenet_classification.data_providers import ImagenetDataProvider

__all__ = ['RunConfig', 'ImagenetRunConfig', 'DistributedImageNetRunConfig']


class RunConfig:

    def __init__(self, n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
                 dataset, train_batch_size, test_batch_size, valid_size,
                 opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
                 mixup_alpha, model_init, validation_frequency, print_frequency):
        self.n_epochs = n_epochs
        self.init_lr = init_lr
        self.lr_schedule_type = lr_schedule_type
        self.lr_schedule_param = lr_schedule_param

        self.dataset = dataset
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.valid_size = valid_size

        self.opt_type = opt_type
        self.opt_param = opt_param
        self.weight_decay = weight_decay
        self.label_smoothing = label_smoothing
        self.no_decay_keys = no_decay_keys

        self.mixup_alpha = mixup_alpha

        self.model_init = model_init
        self.validation_frequency = validation_frequency
        self.print_frequency = print_frequency

    @property
    def config(self):
        config = {}
        for key in self.__dict__:
            if not key.startswith('_'):
                config[key] = self.__dict__[key]
        return config

    def copy(self):
        return RunConfig(**self.config)

    """ learning rate """

    def adjust_learning_rate(self, optimizer, epoch, batch=0, nBatch=None):
        """ adjust learning of a given optimizer and return the new learning rate """
        new_lr = calc_learning_rate(epoch, self.init_lr, self.n_epochs, batch, nBatch, self.lr_schedule_type)
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
        return new_lr

    def warmup_adjust_learning_rate(self, optimizer, T_total, nBatch, epoch, batch=0, warmup_lr=0):
        T_cur = epoch * nBatch + batch + 1
        new_lr = T_cur / T_total * (self.init_lr - warmup_lr) + warmup_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
        return new_lr

    """ data provider """

    @property
    def data_provider(self):
        raise NotImplementedError

    @property
    def train_loader(self):
        return self.data_provider.train

    @property
    def valid_loader(self):
        return self.data_provider.valid

    @property
    def test_loader(self):
        return self.data_provider.test

    def random_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        return self.data_provider.build_sub_train_loader(n_images, batch_size, num_worker, num_replicas, rank)

    """ optimizer """

    def build_optimizer(self, net_params):
        return build_optimizer(net_params,
                               self.opt_type, self.opt_param, self.init_lr, self.weight_decay, self.no_decay_keys)


class ImagenetRunConfig(RunConfig):

    def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='imagenet', train_batch_size=256, test_batch_size=500, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys=None,
                 mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224, **kwargs):
        super(ImagenetRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )

        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == ImagenetDataProvider.name():
                DataProviderClass = ImagenetDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class DistributedImageNetRunConfig(ImagenetRunConfig):

    def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='imagenet', train_batch_size=64, test_batch_size=64, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys=None,
                 mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=8, resize_scale=0.08, distort_color='tf', image_size=224,
                 **kwargs):
        super(DistributedImageNetRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha, model_init, validation_frequency, print_frequency, n_worker, resize_scale, distort_color,
            image_size, **kwargs
        )

        self._num_replicas = kwargs['num_replicas']
        self._rank = kwargs['rank']

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == ImagenetDataProvider.name():
                DataProviderClass = ImagenetDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
                num_replicas=self._num_replicas, rank=self._rank,
            )
        return self.__dict__['_data_provider']
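
The warmup rule above is a pure linear ramp: new_lr = T_cur / T_total * (init_lr - warmup_lr) + warmup_lr, with T_cur = epoch * nBatch + batch + 1, so halfway through warmup the learning rate is halfway between warmup_lr and init_lr. A small standalone sketch (illustrative; the throwaway optimizer is hypothetical):

# usage sketch (illustrative)
import torch

cfg = ImagenetRunConfig(init_lr=0.05)
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=cfg.init_lr)
# 5 warmup epochs x 100 batches => T_total = 500; at epoch 2, batch 49, T_cur = 250
print(cfg.warmup_adjust_learning_rate(opt, T_total=500, nBatch=100, epoch=2, batch=49))  # 0.025
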
@@ -0,0 +1,375 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import random
import time
import json
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from tqdm import tqdm

from ofa_local.utils import get_net_info, cross_entropy_loss_with_soft_target, cross_entropy_with_label_smoothing
from ofa_local.utils import AverageMeter, accuracy, write_log, mix_images, mix_labels, init_models
from ofa_local.utils import MyRandomResizedCrop

__all__ = ['RunManager']


class RunManager:

    def __init__(self, path, net, run_config, init=True, measure_latency=None, no_gpu=False):
        self.path = path
        self.net = net
        self.run_config = run_config

        self.best_acc = 0
        self.start_epoch = 0

        os.makedirs(self.path, exist_ok=True)

        # move network to GPU if available
        if torch.cuda.is_available() and (not no_gpu):
            self.device = torch.device('cuda:0')
            self.net = self.net.to(self.device)
            cudnn.benchmark = True
        else:
            self.device = torch.device('cpu')
        # initialize model (default); init_models expects (net, model_init)
        if init:
            init_models(self.net, run_config.model_init)

        # net info
        net_info = get_net_info(self.net, self.run_config.data_provider.data_shape, measure_latency, True)
        with open('%s/net_info.txt' % self.path, 'w') as fout:
            fout.write(json.dumps(net_info, indent=4) + '\n')
            # noinspection PyBroadException
            try:
                fout.write(self.network.module_str + '\n')
            except Exception:
                pass
            fout.write('%s\n' % self.run_config.data_provider.train.dataset.transform)
            fout.write('%s\n' % self.run_config.data_provider.test.dataset.transform)
            fout.write('%s\n' % self.network)

        # criterion
        if isinstance(self.run_config.mixup_alpha, float):
            self.train_criterion = cross_entropy_loss_with_soft_target
        elif self.run_config.label_smoothing > 0:
            self.train_criterion = \
                lambda pred, target: cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
        else:
            self.train_criterion = nn.CrossEntropyLoss()
        self.test_criterion = nn.CrossEntropyLoss()

        # optimizer
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split('#')
            net_params = [
                self.network.get_parameters(keys, mode='exclude'),  # parameters with weight decay
                self.network.get_parameters(keys, mode='include'),  # parameters without weight decay
            ]
        else:
            # noinspection PyBroadException
            try:
                net_params = self.network.weight_parameters()
            except Exception:
                net_params = []
                for param in self.network.parameters():
                    if param.requires_grad:
                        net_params.append(param)
        self.optimizer = self.run_config.build_optimizer(net_params)

        self.net = torch.nn.DataParallel(self.net)

    """ save path and log path """

    @property
    def save_path(self):
        if self.__dict__.get('_save_path', None) is None:
            save_path = os.path.join(self.path, 'checkpoint')
            os.makedirs(save_path, exist_ok=True)
            self.__dict__['_save_path'] = save_path
        return self.__dict__['_save_path']

    @property
    def logs_path(self):
        if self.__dict__.get('_logs_path', None) is None:
            logs_path = os.path.join(self.path, 'logs')
            os.makedirs(logs_path, exist_ok=True)
            self.__dict__['_logs_path'] = logs_path
        return self.__dict__['_logs_path']

    @property
    def network(self):
        return self.net.module if isinstance(self.net, nn.DataParallel) else self.net

    def write_log(self, log_str, prefix='valid', should_print=True, mode='a'):
        write_log(self.logs_path, log_str, prefix, should_print, mode)

    """ save and load models """

    def save_model(self, checkpoint=None, is_best=False, model_name=None):
        if checkpoint is None:
            checkpoint = {'state_dict': self.network.state_dict()}

        if model_name is None:
            model_name = 'checkpoint.pth.tar'

        checkpoint['dataset'] = self.run_config.dataset  # add `dataset` info to the checkpoint
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        model_path = os.path.join(self.save_path, model_name)
        with open(latest_fname, 'w') as fout:
            fout.write(model_path + '\n')
        torch.save(checkpoint, model_path)

        if is_best:
            best_path = os.path.join(self.save_path, 'model_best.pth.tar')
            torch.save({'state_dict': checkpoint['state_dict']}, best_path)

    def load_model(self, model_fname=None):
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        if model_fname is None and os.path.exists(latest_fname):
            with open(latest_fname, 'r') as fin:
                model_fname = fin.readline()
                if model_fname[-1] == '\n':
                    model_fname = model_fname[:-1]
        # noinspection PyBroadException
        try:
            if model_fname is None or not os.path.exists(model_fname):
                model_fname = '%s/checkpoint.pth.tar' % self.save_path
                with open(latest_fname, 'w') as fout:
                    fout.write(model_fname + '\n')
            print("=> loading checkpoint '{}'".format(model_fname))
            checkpoint = torch.load(model_fname, map_location='cpu')
        except Exception:
            print('fail to load checkpoint from %s' % self.save_path)
            return {}

        self.network.load_state_dict(checkpoint['state_dict'])
        if 'epoch' in checkpoint:
            self.start_epoch = checkpoint['epoch'] + 1
        if 'best_acc' in checkpoint:
            self.best_acc = checkpoint['best_acc']
        if 'optimizer' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        print("=> loaded checkpoint '{}'".format(model_fname))
        return checkpoint

    def save_config(self, extra_run_config=None, extra_net_config=None):
        """ dump run_config and net_config to the model_folder """
        run_save_path = os.path.join(self.path, 'run.config')
        if not os.path.isfile(run_save_path):
            run_config = self.run_config.config
            if extra_run_config is not None:
                run_config.update(extra_run_config)
            json.dump(run_config, open(run_save_path, 'w'), indent=4)
            print('Run configs dump to %s' % run_save_path)

        try:
            net_save_path = os.path.join(self.path, 'net.config')
            net_config = self.network.config
            if extra_net_config is not None:
                net_config.update(extra_net_config)
            json.dump(net_config, open(net_save_path, 'w'), indent=4)
            print('Network configs dump to %s' % net_save_path)
        except Exception:
            print('%s do not support net config' % type(self.network))

    """ metric related """

    def get_metric_dict(self):
        return {
            'top1': AverageMeter(),
            'top5': AverageMeter(),
        }

    def update_metric(self, metric_dict, output, labels):
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        metric_dict['top1'].update(acc1[0].item(), output.size(0))
        metric_dict['top5'].update(acc5[0].item(), output.size(0))

    def get_metric_vals(self, metric_dict, return_dict=False):
        if return_dict:
            return {
                key: metric_dict[key].avg for key in metric_dict
            }
        else:
            return [metric_dict[key].avg for key in metric_dict]

    def get_metric_names(self):
        return 'top1', 'top5'

    """ train and test """

    def validate(self, epoch=0, is_test=False, run_str='', net=None, data_loader=None, no_logs=False, train_mode=False):
        if net is None:
            net = self.net
        if not isinstance(net, nn.DataParallel):
            net = nn.DataParallel(net)

        if data_loader is None:
            data_loader = self.run_config.test_loader if is_test else self.run_config.valid_loader
|
||||
|
||||
if train_mode:
|
||||
net.train()
|
||||
else:
|
||||
net.eval()
|
||||
|
||||
losses = AverageMeter()
|
||||
metric_dict = self.get_metric_dict()
|
||||
|
||||
with torch.no_grad():
|
||||
with tqdm(total=len(data_loader),
|
||||
desc='Validate Epoch #{} {}'.format(epoch + 1, run_str), disable=no_logs) as t:
|
||||
for i, (images, labels) in enumerate(data_loader):
|
||||
images, labels = images.to(self.device), labels.to(self.device)
|
||||
# compute output
|
||||
output = net(images)
|
||||
loss = self.test_criterion(output, labels)
|
||||
# measure accuracy and record loss
|
||||
self.update_metric(metric_dict, output, labels)
|
||||
|
||||
losses.update(loss.item(), images.size(0))
|
||||
t.set_postfix({
|
||||
'loss': losses.avg,
|
||||
**self.get_metric_vals(metric_dict, return_dict=True),
|
||||
'img_size': images.size(2),
|
||||
})
|
||||
t.update(1)
|
||||
return losses.avg, self.get_metric_vals(metric_dict)
|
||||
|
||||
def validate_all_resolution(self, epoch=0, is_test=False, net=None):
|
||||
if net is None:
|
||||
net = self.network
|
||||
if isinstance(self.run_config.data_provider.image_size, list):
|
||||
img_size_list, loss_list, top1_list, top5_list = [], [], [], []
|
||||
for img_size in self.run_config.data_provider.image_size:
|
||||
img_size_list.append(img_size)
|
||||
self.run_config.data_provider.assign_active_img_size(img_size)
|
||||
self.reset_running_statistics(net=net)
|
||||
loss, (top1, top5) = self.validate(epoch, is_test, net=net)
|
||||
loss_list.append(loss)
|
||||
top1_list.append(top1)
|
||||
top5_list.append(top5)
|
||||
return img_size_list, loss_list, top1_list, top5_list
|
||||
else:
|
||||
loss, (top1, top5) = self.validate(epoch, is_test, net=net)
|
||||
return [self.run_config.data_provider.active_img_size], [loss], [top1], [top5]
|
||||
|
||||
def train_one_epoch(self, args, epoch, warmup_epochs=0, warmup_lr=0):
|
||||
# switch to train mode
|
||||
self.net.train()
|
||||
MyRandomResizedCrop.EPOCH = epoch # required by elastic resolution
|
||||
|
||||
nBatch = len(self.run_config.train_loader)
|
||||
|
||||
losses = AverageMeter()
|
||||
metric_dict = self.get_metric_dict()
|
||||
data_time = AverageMeter()
|
||||
|
||||
with tqdm(total=nBatch,
|
||||
desc='{} Train Epoch #{}'.format(self.run_config.dataset, epoch + 1)) as t:
|
||||
end = time.time()
|
||||
for i, (images, labels) in enumerate(self.run_config.train_loader):
|
||||
MyRandomResizedCrop.BATCH = i
|
||||
data_time.update(time.time() - end)
|
||||
if epoch < warmup_epochs:
|
||||
new_lr = self.run_config.warmup_adjust_learning_rate(
|
||||
self.optimizer, warmup_epochs * nBatch, nBatch, epoch, i, warmup_lr,
|
||||
)
|
||||
else:
|
||||
new_lr = self.run_config.adjust_learning_rate(self.optimizer, epoch - warmup_epochs, i, nBatch)
|
||||
|
||||
images, labels = images.to(self.device), labels.to(self.device)
|
||||
target = labels
|
||||
if isinstance(self.run_config.mixup_alpha, float):
|
||||
# transform data
|
||||
lam = random.betavariate(self.run_config.mixup_alpha, self.run_config.mixup_alpha)
|
||||
images = mix_images(images, lam)
|
||||
labels = mix_labels(
|
||||
labels, lam, self.run_config.data_provider.n_classes, self.run_config.label_smoothing
|
||||
)
|
||||
|
||||
# soft target
|
||||
if args.teacher_model is not None:
|
||||
args.teacher_model.train()
|
||||
with torch.no_grad():
|
||||
soft_logits = args.teacher_model(images).detach()
|
||||
soft_label = F.softmax(soft_logits, dim=1)
|
||||
|
||||
# compute output
|
||||
output = self.net(images)
|
||||
loss = self.train_criterion(output, labels)
|
||||
|
||||
if args.teacher_model is None:
|
||||
loss_type = 'ce'
|
||||
else:
|
||||
if args.kd_type == 'ce':
|
||||
kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
|
||||
else:
|
||||
kd_loss = F.mse_loss(output, soft_logits)
|
||||
loss = args.kd_ratio * kd_loss + loss
|
||||
loss_type = '%.1fkd+ce' % args.kd_ratio
|
||||
|
||||
# compute gradient and do SGD step
|
||||
self.net.zero_grad() # or self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
# measure accuracy and record loss
|
||||
losses.update(loss.item(), images.size(0))
|
||||
self.update_metric(metric_dict, output, target)
|
||||
|
||||
t.set_postfix({
|
||||
'loss': losses.avg,
|
||||
**self.get_metric_vals(metric_dict, return_dict=True),
|
||||
'img_size': images.size(2),
|
||||
'lr': new_lr,
|
||||
'loss_type': loss_type,
|
||||
'data_time': data_time.avg,
|
||||
})
|
||||
t.update(1)
|
||||
end = time.time()
|
||||
return losses.avg, self.get_metric_vals(metric_dict)
|
||||
|
||||
def train(self, args, warmup_epoch=0, warmup_lr=0):
|
||||
for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epoch):
|
||||
train_loss, (train_top1, train_top5) = self.train_one_epoch(args, epoch, warmup_epoch, warmup_lr)
|
||||
|
||||
if (epoch + 1) % self.run_config.validation_frequency == 0:
|
||||
img_size, val_loss, val_acc, val_acc5 = self.validate_all_resolution(epoch=epoch, is_test=False)
|
||||
|
||||
is_best = np.mean(val_acc) > self.best_acc
|
||||
self.best_acc = max(self.best_acc, np.mean(val_acc))
|
||||
val_log = 'Valid [{0}/{1}]\tloss {2:.3f}\t{5} {3:.3f} ({4:.3f})'. \
|
||||
format(epoch + 1 - warmup_epoch, self.run_config.n_epochs,
|
||||
np.mean(val_loss), np.mean(val_acc), self.best_acc, self.get_metric_names()[0])
|
||||
val_log += '\t{2} {0:.3f}\tTrain {1} {top1:.3f}\tloss {train_loss:.3f}\t'. \
|
||||
format(np.mean(val_acc5), *self.get_metric_names(), top1=train_top1, train_loss=train_loss)
|
||||
for i_s, v_a in zip(img_size, val_acc):
|
||||
val_log += '(%d, %.3f), ' % (i_s, v_a)
|
||||
self.write_log(val_log, prefix='valid', should_print=False)
|
||||
else:
|
||||
is_best = False
|
||||
|
||||
self.save_model({
|
||||
'epoch': epoch,
|
||||
'best_acc': self.best_acc,
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'state_dict': self.network.state_dict(),
|
||||
}, is_best=is_best)
|
||||
|
||||
def reset_running_statistics(self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None):
|
||||
from ofa.imagenet_classification.elastic_nn.utils import set_running_statistics
|
||||
if net is None:
|
||||
net = self.network
|
||||
if data_loader is None:
|
||||
data_loader = self.run_config.random_sub_train_loader(subset_size, subset_batch_size)
|
||||
set_running_statistics(net, data_loader)
|
||||
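
# Illustrative usage sketch (not part of the original file; the run-config
# class name, paths, and `args` object below are assumptions):
#
# run_config = ImagenetRunConfig(n_epochs=150, train_batch_size=256)
# run_manager = RunManager('./exp/default', net, run_config, init=True)
# run_manager.save_config()
# run_manager.train(args)                                  # train + periodic validation
# loss, (top1, top5) = run_manager.validate(is_test=True)  # final test-set evaluation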
@@ -0,0 +1,87 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import json
import torch

from ofa_local.utils import download_url
from ofa_local.imagenet_classification.networks import get_net_by_name, proxyless_base
from ofa_local.imagenet_classification.elastic_nn.networks import OFAMobileNetV3, OFAProxylessNASNets, OFAResNets

__all__ = [
    'ofa_specialized', 'ofa_net',
    'proxylessnas_net', 'proxylessnas_mobile', 'proxylessnas_cpu', 'proxylessnas_gpu',
]


def ofa_specialized(net_id, pretrained=True):
    url_base = 'https://hanlab.mit.edu/files/OnceForAll/ofa_specialized/'
    net_config = json.load(open(
        download_url(url_base + net_id + '/net.config', model_dir='.torch/ofa_specialized/%s/' % net_id)
    ))
    net = get_net_by_name(net_config['name']).build_from_config(net_config)

    image_size = json.load(open(
        download_url(url_base + net_id + '/run.config', model_dir='.torch/ofa_specialized/%s/' % net_id)
    ))['image_size']

    if pretrained:
        init = torch.load(
            download_url(url_base + net_id + '/init', model_dir='.torch/ofa_specialized/%s/' % net_id),
            map_location='cpu'
        )['state_dict']
        net.load_state_dict(init)
    return net, image_size


def ofa_net(net_id, pretrained=True):
    if net_id == 'ofa_proxyless_d234_e346_k357_w1.3':
        net = OFAProxylessNASNets(
            dropout_rate=0, width_mult=1.3, ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4],
        )
    elif net_id == 'ofa_mbv3_d234_e346_k357_w1.0':
        net = OFAMobileNetV3(
            dropout_rate=0, width_mult=1.0, ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4],
        )
    elif net_id == 'ofa_mbv3_d234_e346_k357_w1.2':
        net = OFAMobileNetV3(
            dropout_rate=0, width_mult=1.2, ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4],
        )
    elif net_id == 'ofa_resnet50':
        net = OFAResNets(
            dropout_rate=0, depth_list=[0, 1, 2], expand_ratio_list=[0.2, 0.25, 0.35], width_mult_list=[0.65, 0.8, 1.0]
        )
        net_id = 'ofa_resnet50_d=0+1+2_e=0.2+0.25+0.35_w=0.65+0.8+1.0'
    else:
        raise ValueError('Not supported: %s' % net_id)

    if pretrained:
        url_base = 'https://hanlab.mit.edu/files/OnceForAll/ofa_nets/'
        init = torch.load(
            download_url(url_base + net_id, model_dir='.torch/ofa_nets'),
            map_location='cpu')['state_dict']
        net.load_state_dict(init)
    return net


def proxylessnas_net(net_id, pretrained=True):
    net = proxyless_base(
        net_config='https://hanlab.mit.edu/files/proxylessNAS/%s.config' % net_id,
    )
    if pretrained:
        net.load_state_dict(torch.load(
            download_url('https://hanlab.mit.edu/files/proxylessNAS/%s.pth' % net_id), map_location='cpu'
        )['state_dict'])
    return net  # was missing: the wrappers below would otherwise return None


def proxylessnas_mobile(pretrained=True):
    return proxylessnas_net('proxyless_mobile', pretrained)


def proxylessnas_cpu(pretrained=True):
    return proxylessnas_net('proxyless_cpu', pretrained)


def proxylessnas_gpu(pretrained=True):
    return proxylessnas_net('proxyless_gpu', pretrained)
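

# Illustrative usage sketch (not part of the original file): instantiate the
# MobileNetV3-based OFA supernet without downloading weights, randomly pick a
# subnet configuration, and materialize that subnet. `sample_active_subnet`
# and `get_active_subnet` are standard OFA supernet methods.
if __name__ == '__main__':
    supernet = ofa_net('ofa_mbv3_d234_e346_k357_w1.0', pretrained=False)
    setting = supernet.sample_active_subnet()  # random ks/e per block, depth per stage
    subnet = supernet.get_active_subnet(preserve_weight=True)
    print(setting)
    print(subnet)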
@@ -0,0 +1,7 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .acc_dataset import *
from .acc_predictor import *
from .arch_encoder import *
@@ -0,0 +1,181 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import json
import numpy as np
from tqdm import tqdm
import torch
import torch.utils.data

from ofa.utils import list_mean

__all__ = ['net_setting2id', 'net_id2setting', 'AccuracyDataset']


def net_setting2id(net_setting):
    return json.dumps(net_setting)


def net_id2setting(net_id):
    return json.loads(net_id)


class RegDataset(torch.utils.data.Dataset):

    def __init__(self, inputs, targets):
        super(RegDataset, self).__init__()
        self.inputs = inputs
        self.targets = targets

    def __getitem__(self, index):
        return self.inputs[index], self.targets[index]

    def __len__(self):
        return self.inputs.size(0)


class AccuracyDataset:

    def __init__(self, path):
        self.path = path
        os.makedirs(self.path, exist_ok=True)

    @property
    def net_id_path(self):
        return os.path.join(self.path, 'net_id.dict')

    @property
    def acc_src_folder(self):
        return os.path.join(self.path, 'src')

    @property
    def acc_dict_path(self):
        return os.path.join(self.path, 'acc.dict')

    # TODO: support parallel building
    def build_acc_dataset(self, run_manager, ofa_network, n_arch=1000, image_size_list=None):
        # load net_id_list, random sample if not exist
        if os.path.isfile(self.net_id_path):
            net_id_list = json.load(open(self.net_id_path))
        else:
            net_id_list = set()
            while len(net_id_list) < n_arch:
                net_setting = ofa_network.sample_active_subnet()
                net_id = net_setting2id(net_setting)
                net_id_list.add(net_id)
            net_id_list = list(net_id_list)
            net_id_list.sort()
            json.dump(net_id_list, open(self.net_id_path, 'w'), indent=4)

        image_size_list = [128, 160, 192, 224] if image_size_list is None else image_size_list

        with tqdm(total=len(net_id_list) * len(image_size_list), desc='Building Acc Dataset') as t:
            for image_size in image_size_list:
                # load val dataset into memory
                val_dataset = []
                run_manager.run_config.data_provider.assign_active_img_size(image_size)
                for images, labels in run_manager.run_config.valid_loader:
                    val_dataset.append((images, labels))
                # save path
                os.makedirs(self.acc_src_folder, exist_ok=True)
                acc_save_path = os.path.join(self.acc_src_folder, '%d.dict' % image_size)
                acc_dict = {}
                # load existing acc dict
                if os.path.isfile(acc_save_path):
                    existing_acc_dict = json.load(open(acc_save_path, 'r'))
                else:
                    existing_acc_dict = {}
                for net_id in net_id_list:
                    net_setting = net_id2setting(net_id)
                    key = net_setting2id({**net_setting, 'image_size': image_size})
                    if key in existing_acc_dict:
                        acc_dict[key] = existing_acc_dict[key]
                        t.set_postfix({
                            'net_id': net_id,
                            'image_size': image_size,
                            'info_val': acc_dict[key],
                            'status': 'loading',
                        })
                        t.update()
                        continue
                    ofa_network.set_active_subnet(**net_setting)
                    run_manager.reset_running_statistics(ofa_network)
                    net_setting_str = ','.join(['%s_%s' % (
                        key, '%.1f' % list_mean(val) if isinstance(val, list) else val
                    ) for key, val in net_setting.items()])
                    loss, (top1, top5) = run_manager.validate(
                        run_str=net_setting_str, net=ofa_network, data_loader=val_dataset, no_logs=True,
                    )
                    info_val = top1

                    t.set_postfix({
                        'net_id': net_id,
                        'image_size': image_size,
                        'info_val': info_val,
                    })
                    t.update()

                    acc_dict.update({
                        key: info_val
                    })
                    json.dump(acc_dict, open(acc_save_path, 'w'), indent=4)

    def merge_acc_dataset(self, image_size_list=None):
        # load existing data
        merged_acc_dict = {}
        for fname in os.listdir(self.acc_src_folder):
            if '.dict' not in fname:
                continue
            image_size = int(fname.split('.dict')[0])
            if image_size_list is not None and image_size not in image_size_list:
                print('Skip ', fname)
                continue
            full_path = os.path.join(self.acc_src_folder, fname)
            partial_acc_dict = json.load(open(full_path))
            merged_acc_dict.update(partial_acc_dict)
            print('loaded %s' % full_path)
        json.dump(merged_acc_dict, open(self.acc_dict_path, 'w'), indent=4)
        return merged_acc_dict

    def build_acc_data_loader(self, arch_encoder, n_training_sample=None, batch_size=256, n_workers=16):
        # load data
        acc_dict = json.load(open(self.acc_dict_path))
        X_all = []
        Y_all = []
        with tqdm(total=len(acc_dict), desc='Loading data') as t:
            for k, v in acc_dict.items():
                dic = json.loads(k)
                X_all.append(arch_encoder.arch2feature(dic))
                Y_all.append(v / 100.)  # range: 0 - 1
                t.update()
        base_acc = np.mean(Y_all)
        # convert to torch tensor
        X_all = torch.tensor(X_all, dtype=torch.float)
        Y_all = torch.tensor(Y_all)

        # random shuffle
        shuffle_idx = torch.randperm(len(X_all))
        X_all = X_all[shuffle_idx]
        Y_all = Y_all[shuffle_idx]

        # split data
        idx = X_all.size(0) // 5 * 4 if n_training_sample is None else n_training_sample
        val_idx = X_all.size(0) // 5 * 4
        X_train, Y_train = X_all[:idx], Y_all[:idx]
        X_test, Y_test = X_all[val_idx:], Y_all[val_idx:]
        print('Train Size: %d,' % len(X_train), 'Valid Size: %d' % len(X_test))

        # build data loader
        train_dataset = RegDataset(X_train, Y_train)
        val_dataset = RegDataset(X_test, Y_test)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, pin_memory=False, num_workers=n_workers
        )
        valid_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False, pin_memory=False, num_workers=n_workers
        )

        return train_loader, valid_loader, base_acc
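
# Illustrative end-to-end sketch (commented out; `run_manager`, `ofa_network`
# and `arch_encoder` are assumed to be constructed elsewhere in this commit):
#
# acc_dataset = AccuracyDataset('./data/acc_dataset')
# acc_dataset.build_acc_dataset(run_manager, ofa_network, n_arch=1000)
# acc_dataset.merge_acc_dataset()
# train_loader, valid_loader, base_acc = acc_dataset.build_acc_data_loader(arch_encoder)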
@@ -0,0 +1,50 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import numpy as np
import torch
import torch.nn as nn

__all__ = ['AccuracyPredictor']


class AccuracyPredictor(nn.Module):

    def __init__(self, arch_encoder, hidden_size=400, n_layers=3,
                 checkpoint_path=None, device='cuda:0'):
        super(AccuracyPredictor, self).__init__()
        self.arch_encoder = arch_encoder
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.device = device

        # build layers
        layers = []
        for i in range(self.n_layers):
            layers.append(nn.Sequential(
                nn.Linear(self.arch_encoder.n_dim if i == 0 else self.hidden_size, self.hidden_size),
                nn.ReLU(inplace=True),
            ))
        layers.append(nn.Linear(self.hidden_size, 1, bias=False))
        self.layers = nn.Sequential(*layers)
        self.base_acc = nn.Parameter(torch.zeros(1, device=self.device), requires_grad=False)

        if checkpoint_path is not None and os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            if 'state_dict' in checkpoint:
                checkpoint = checkpoint['state_dict']
            self.load_state_dict(checkpoint)
            print('Loaded checkpoint from %s' % checkpoint_path)

        self.layers = self.layers.to(self.device)

    def forward(self, x):
        y = self.layers(x).squeeze()
        return y + self.base_acc

    def predict_acc(self, arch_dict_list):
        X = [self.arch_encoder.arch2feature(arch_dict) for arch_dict in arch_dict_list]
        X = torch.tensor(np.array(X)).float().to(self.device)
        return self.forward(X)
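
# Illustrative usage sketch (commented out; assumes the MobileNetArchEncoder
# defined in arch_encoder.py, and no checkpoint, so the outputs are untrained):
#
# encoder = MobileNetArchEncoder(image_size_list=[128, 160, 192, 224])
# predictor = AccuracyPredictor(encoder, checkpoint_path=None, device='cpu')
# archs = [encoder.random_sample_arch() for _ in range(4)]
# print(predictor.predict_acc(archs))  # predicted top-1 accuracy in [0, 1]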
@@ -0,0 +1,315 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.


import random
import numpy as np
from ofa.imagenet_classification.networks import ResNets

__all__ = ['MobileNetArchEncoder', 'ResNetArchEncoder']


class MobileNetArchEncoder:
    SPACE_TYPE = 'mbv3'

    def __init__(self, image_size_list=None, ks_list=None, expand_list=None, depth_list=None, n_stage=None):
        self.image_size_list = [224] if image_size_list is None else image_size_list
        self.ks_list = [3, 5, 7] if ks_list is None else ks_list
        self.expand_list = [3, 4, 6] if expand_list is None else [int(expand) for expand in expand_list]
        self.depth_list = [2, 3, 4] if depth_list is None else depth_list
        if n_stage is not None:
            self.n_stage = n_stage
        elif self.SPACE_TYPE == 'mbv2':
            self.n_stage = 6
        elif self.SPACE_TYPE == 'mbv3':
            self.n_stage = 5
        else:
            raise NotImplementedError

        # build info dict
        self.n_dim = 0
        self.r_info = dict(id2val={}, val2id={}, L=[], R=[])
        self._build_info_dict(target='r')

        self.k_info = dict(id2val=[], val2id=[], L=[], R=[])
        self.e_info = dict(id2val=[], val2id=[], L=[], R=[])
        self._build_info_dict(target='k')
        self._build_info_dict(target='e')

    @property
    def max_n_blocks(self):
        if self.SPACE_TYPE == 'mbv3':
            return self.n_stage * max(self.depth_list)
        elif self.SPACE_TYPE == 'mbv2':
            return (self.n_stage - 1) * max(self.depth_list) + 1
        else:
            raise NotImplementedError

    def _build_info_dict(self, target):
        if target == 'r':
            target_dict = self.r_info
            target_dict['L'].append(self.n_dim)
            for img_size in self.image_size_list:
                target_dict['val2id'][img_size] = self.n_dim
                target_dict['id2val'][self.n_dim] = img_size
                self.n_dim += 1
            target_dict['R'].append(self.n_dim)
        else:
            if target == 'k':
                target_dict = self.k_info
                choices = self.ks_list
            elif target == 'e':
                target_dict = self.e_info
                choices = self.expand_list
            else:
                raise NotImplementedError
            for i in range(self.max_n_blocks):
                target_dict['val2id'].append({})
                target_dict['id2val'].append({})
                target_dict['L'].append(self.n_dim)
                for k in choices:
                    target_dict['val2id'][i][k] = self.n_dim
                    target_dict['id2val'][i][self.n_dim] = k
                    self.n_dim += 1
                target_dict['R'].append(self.n_dim)

    def arch2feature(self, arch_dict):
        ks, e, d, r = arch_dict['ks'], arch_dict['e'], arch_dict['d'], arch_dict['image_size']

        feature = np.zeros(self.n_dim)
        for i in range(self.max_n_blocks):
            nowd = i % max(self.depth_list)
            stg = i // max(self.depth_list)
            if nowd < d[stg]:
                feature[self.k_info['val2id'][i][ks[i]]] = 1
                feature[self.e_info['val2id'][i][e[i]]] = 1
        feature[self.r_info['val2id'][r]] = 1
        return feature

    def feature2arch(self, feature):
        img_sz = self.r_info['id2val'][
            int(np.argmax(feature[self.r_info['L'][0]:self.r_info['R'][0]])) + self.r_info['L'][0]
        ]
        assert img_sz in self.image_size_list
        arch_dict = {'ks': [], 'e': [], 'd': [], 'image_size': img_sz}

        d = 0
        for i in range(self.max_n_blocks):
            skip = True
            for j in range(self.k_info['L'][i], self.k_info['R'][i]):
                if feature[j] == 1:
                    arch_dict['ks'].append(self.k_info['id2val'][i][j])
                    skip = False
                    break

            for j in range(self.e_info['L'][i], self.e_info['R'][i]):
                if feature[j] == 1:
                    arch_dict['e'].append(self.e_info['id2val'][i][j])
                    assert not skip
                    skip = False
                    break

            if skip:
                arch_dict['e'].append(0)
                arch_dict['ks'].append(0)
            else:
                d += 1

            if (i + 1) % max(self.depth_list) == 0 or (i + 1) == self.max_n_blocks:
                arch_dict['d'].append(d)
                d = 0
        return arch_dict

    def random_sample_arch(self):
        return {
            'ks': random.choices(self.ks_list, k=self.max_n_blocks),
            'e': random.choices(self.expand_list, k=self.max_n_blocks),
            'd': random.choices(self.depth_list, k=self.n_stage),
            'image_size': random.choice(self.image_size_list)
        }

    def mutate_resolution(self, arch_dict, mutate_prob):
        if random.random() < mutate_prob:
            arch_dict['image_size'] = random.choice(self.image_size_list)
        return arch_dict

    def mutate_arch(self, arch_dict, mutate_prob):
        for i in range(self.max_n_blocks):
            if random.random() < mutate_prob:
                arch_dict['ks'][i] = random.choice(self.ks_list)
                arch_dict['e'][i] = random.choice(self.expand_list)

        for i in range(self.n_stage):
            if random.random() < mutate_prob:
                arch_dict['d'][i] = random.choice(self.depth_list)
        return arch_dict


class ResNetArchEncoder:

    def __init__(self, image_size_list=None, depth_list=None, expand_list=None, width_mult_list=None,
                 base_depth_list=None):
        self.image_size_list = [224] if image_size_list is None else image_size_list
        self.expand_list = [0.2, 0.25, 0.35] if expand_list is None else expand_list
        self.depth_list = [0, 1, 2] if depth_list is None else depth_list
        self.width_mult_list = [0.65, 0.8, 1.0] if width_mult_list is None else width_mult_list

        self.base_depth_list = ResNets.BASE_DEPTH_LIST if base_depth_list is None else base_depth_list

        """ build info dict """
        self.n_dim = 0
        # resolution
        self.r_info = dict(id2val={}, val2id={}, L=[], R=[])
        self._build_info_dict(target='r')
        # input stem skip
        self.input_stem_d_info = dict(id2val={}, val2id={}, L=[], R=[])
        self._build_info_dict(target='input_stem_d')
        # width_mult
        self.width_mult_info = dict(id2val=[], val2id=[], L=[], R=[])
        self._build_info_dict(target='width_mult')
        # expand ratio
        self.e_info = dict(id2val=[], val2id=[], L=[], R=[])
        self._build_info_dict(target='e')

    @property
    def n_stage(self):
        return len(self.base_depth_list)

    @property
    def max_n_blocks(self):
        return sum(self.base_depth_list) + self.n_stage * max(self.depth_list)

    def _build_info_dict(self, target):
        if target == 'r':
            target_dict = self.r_info
            target_dict['L'].append(self.n_dim)
            for img_size in self.image_size_list:
                target_dict['val2id'][img_size] = self.n_dim
                target_dict['id2val'][self.n_dim] = img_size
                self.n_dim += 1
            target_dict['R'].append(self.n_dim)
        elif target == 'input_stem_d':
            target_dict = self.input_stem_d_info
            target_dict['L'].append(self.n_dim)
            for skip in [0, 1]:
                target_dict['val2id'][skip] = self.n_dim
                target_dict['id2val'][self.n_dim] = skip
                self.n_dim += 1
            target_dict['R'].append(self.n_dim)
        elif target == 'e':
            target_dict = self.e_info
            choices = self.expand_list
            for i in range(self.max_n_blocks):
                target_dict['val2id'].append({})
                target_dict['id2val'].append({})
                target_dict['L'].append(self.n_dim)
                for e in choices:
                    target_dict['val2id'][i][e] = self.n_dim
                    target_dict['id2val'][i][self.n_dim] = e
                    self.n_dim += 1
                target_dict['R'].append(self.n_dim)
        elif target == 'width_mult':
            target_dict = self.width_mult_info
            choices = list(range(len(self.width_mult_list)))
            for i in range(self.n_stage + 2):
                target_dict['val2id'].append({})
                target_dict['id2val'].append({})
                target_dict['L'].append(self.n_dim)
                for w in choices:
                    target_dict['val2id'][i][w] = self.n_dim
                    target_dict['id2val'][i][self.n_dim] = w
                    self.n_dim += 1
                target_dict['R'].append(self.n_dim)

    def arch2feature(self, arch_dict):
        d, e, w, r = arch_dict['d'], arch_dict['e'], arch_dict['w'], arch_dict['image_size']
        input_stem_skip = 1 if d[0] > 0 else 0
        d = d[1:]

        feature = np.zeros(self.n_dim)
        feature[self.r_info['val2id'][r]] = 1
        feature[self.input_stem_d_info['val2id'][input_stem_skip]] = 1
        for i in range(self.n_stage + 2):
            feature[self.width_mult_info['val2id'][i][w[i]]] = 1

        start_pt = 0
        for i, base_depth in enumerate(self.base_depth_list):
            depth = base_depth + d[i]
            for j in range(start_pt, start_pt + depth):
                feature[self.e_info['val2id'][j][e[j]]] = 1
            start_pt += max(self.depth_list) + base_depth

        return feature

    def feature2arch(self, feature):
        img_sz = self.r_info['id2val'][
            int(np.argmax(feature[self.r_info['L'][0]:self.r_info['R'][0]])) + self.r_info['L'][0]
        ]
        input_stem_skip = self.input_stem_d_info['id2val'][
            int(np.argmax(feature[self.input_stem_d_info['L'][0]:self.input_stem_d_info['R'][0]])) +
            self.input_stem_d_info['L'][0]
        ] * 2
        assert img_sz in self.image_size_list
        arch_dict = {'d': [input_stem_skip], 'e': [], 'w': [], 'image_size': img_sz}

        for i in range(self.n_stage + 2):
            arch_dict['w'].append(
                self.width_mult_info['id2val'][i][
                    int(np.argmax(feature[self.width_mult_info['L'][i]:self.width_mult_info['R'][i]])) +
                    self.width_mult_info['L'][i]
                ]
            )

        d = 0
        skipped = 0
        stage_id = 0
        for i in range(self.max_n_blocks):
            skip = True
            for j in range(self.e_info['L'][i], self.e_info['R'][i]):
                if feature[j] == 1:
                    arch_dict['e'].append(self.e_info['id2val'][i][j])
                    skip = False
                    break
            if skip:
                arch_dict['e'].append(0)
                skipped += 1
            else:
                d += 1

            if i + 1 == self.max_n_blocks or (skipped + d) % \
                    (max(self.depth_list) + self.base_depth_list[stage_id]) == 0:
                arch_dict['d'].append(d - self.base_depth_list[stage_id])
                d, skipped = 0, 0
                stage_id += 1
        return arch_dict

    def random_sample_arch(self):
        return {
            'd': [random.choice([0, 2])] + random.choices(self.depth_list, k=self.n_stage),
            'e': random.choices(self.expand_list, k=self.max_n_blocks),
            'w': random.choices(list(range(len(self.width_mult_list))), k=self.n_stage + 2),
            'image_size': random.choice(self.image_size_list)
        }

    def mutate_resolution(self, arch_dict, mutate_prob):
        if random.random() < mutate_prob:
            arch_dict['image_size'] = random.choice(self.image_size_list)
        return arch_dict

    def mutate_arch(self, arch_dict, mutate_prob):
        # input stem skip
        if random.random() < mutate_prob:
            arch_dict['d'][0] = random.choice([0, 2])
        # depth
        for i in range(1, len(arch_dict['d'])):
            if random.random() < mutate_prob:
                arch_dict['d'][i] = random.choice(self.depth_list)
        # width_mult
        for i in range(len(arch_dict['w'])):
            if random.random() < mutate_prob:
                arch_dict['w'][i] = random.choice(list(range(len(self.width_mult_list))))
        # expand ratio
        for i in range(len(arch_dict['e'])):
            if random.random() < mutate_prob:
                arch_dict['e'][i] = random.choice(self.expand_list)
        return arch_dict  # was missing; restores parity with MobileNetArchEncoder.mutate_arch
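

# Illustrative round-trip sketch (not part of the original file): the one-hot
# feature produced by arch2feature decodes back to the same resolution and
# per-stage depths; ks/e entries of skipped blocks are zeroed by design.
if __name__ == '__main__':
    enc = MobileNetArchEncoder(image_size_list=[160, 176, 192, 208, 224])
    arch = enc.random_sample_arch()
    restored = enc.feature2arch(enc.arch2feature(arch))
    assert restored['image_size'] == arch['image_size']
    assert restored['d'] == arch['d']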
@@ -0,0 +1,71 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import copy
from .latency_lookup_table import *


class BaseEfficiencyModel:

    def __init__(self, ofa_net):
        self.ofa_net = ofa_net

    def get_active_subnet_config(self, arch_dict):
        arch_dict = copy.deepcopy(arch_dict)
        image_size = arch_dict.pop('image_size')
        self.ofa_net.set_active_subnet(**arch_dict)
        active_net_config = self.ofa_net.get_active_net_config()
        return active_net_config, image_size

    def get_efficiency(self, arch_dict):
        raise NotImplementedError


class ProxylessNASFLOPsModel(BaseEfficiencyModel):

    def get_efficiency(self, arch_dict):
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        return ProxylessNASLatencyTable.count_flops_given_config(active_net_config, image_size)


class Mbv3FLOPsModel(BaseEfficiencyModel):

    def get_efficiency(self, arch_dict):
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        return MBv3LatencyTable.count_flops_given_config(active_net_config, image_size)


class ResNet50FLOPsModel(BaseEfficiencyModel):

    def get_efficiency(self, arch_dict):
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        return ResNet50LatencyTable.count_flops_given_config(active_net_config, image_size)


class ProxylessNASLatencyModel(BaseEfficiencyModel):

    def __init__(self, ofa_net, lookup_table_path_dict):
        super(ProxylessNASLatencyModel, self).__init__(ofa_net)
        self.latency_tables = {}
        for image_size, path in lookup_table_path_dict.items():
            self.latency_tables[image_size] = ProxylessNASLatencyTable(
                local_dir='/tmp/.ofa_latency_tools/', url=os.path.join(path, '%d_lookup_table.yaml' % image_size))

    def get_efficiency(self, arch_dict):
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        return self.latency_tables[image_size].predict_network_latency_given_config(active_net_config, image_size)


class Mbv3LatencyModel(BaseEfficiencyModel):

    def __init__(self, ofa_net, lookup_table_path_dict):
        super(Mbv3LatencyModel, self).__init__(ofa_net)
        self.latency_tables = {}
        for image_size, path in lookup_table_path_dict.items():
            self.latency_tables[image_size] = MBv3LatencyTable(
                local_dir='/tmp/.ofa_latency_tools/', url=os.path.join(path, '%d_lookup_table.yaml' % image_size))

    def get_efficiency(self, arch_dict):
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        return self.latency_tables[image_size].predict_network_latency_given_config(active_net_config, image_size)
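
# Illustrative usage sketch (commented out; assumes an OFA supernet `ofa_net`
# and an `arch_dict` as produced by MobileNetArchEncoder.random_sample_arch):
#
# flops_model = Mbv3FLOPsModel(ofa_net)
# mflops = flops_model.get_efficiency(arch_dict)  # MFLOPs of the active subnet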
@@ -0,0 +1,387 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import yaml
from ofa.utils import download_url, make_divisible, MyNetwork

__all__ = ['count_conv_flop', 'ProxylessNASLatencyTable', 'MBv3LatencyTable', 'ResNet50LatencyTable']


def count_conv_flop(out_size, in_channels, out_channels, kernel_size, groups):
    out_h = out_w = out_size
    delta_ops = in_channels * out_channels * kernel_size * kernel_size * out_h * out_w / groups
    return delta_ops

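# Worked example: a 3x3, stride-1 convolution mapping 16 -> 32 channels on a
# 56x56 output feature map (groups=1) costs
#   16 * 32 * 3 * 3 * 56 * 56 = 14,450,688 multiply-accumulates:
# count_conv_flop(out_size=56, in_channels=16, out_channels=32, kernel_size=3, groups=1)
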
class LatencyTable(object):

    def __init__(self, local_dir='~/.ofa/latency_tools/',
                 url='https://hanlab.mit.edu/files/proxylessNAS/LatencyTools/mobile_trim.yaml'):
        if url.startswith('http'):
            fname = download_url(url, local_dir, overwrite=True)
        else:
            fname = url
        with open(fname, 'r') as fp:
            # safe_load: yaml.load() without an explicit Loader is deprecated
            # (and an error in PyYAML >= 6); the lookup table is plain data
            self.lut = yaml.safe_load(fp)

    @staticmethod
    def repr_shape(shape):
        if isinstance(shape, (list, tuple)):
            return 'x'.join(str(_) for _ in shape)
        elif isinstance(shape, str):
            return shape
        else:
            raise TypeError('unsupported shape: %s' % repr(shape))  # was `return TypeError`

    def query(self, **kwargs):
        raise NotImplementedError

    def predict_network_latency(self, net, image_size):
        raise NotImplementedError

    def predict_network_latency_given_config(self, net_config, image_size):
        raise NotImplementedError

    @staticmethod
    def count_flops_given_config(net_config, image_size=224):
        raise NotImplementedError


class ProxylessNASLatencyTable(LatencyTable):

    def query(self, l_type: str, input_shape, output_shape, expand=None, ks=None, stride=None, id_skip=None):
        """
        :param l_type:
            Layer type must be one of the following:
            1. `Conv`: The initial 3x3 conv with stride 2.
            2. `Conv_1`: feature_mix_layer
            3. `Logits`: All operations after `Conv_1`.
            4. `expanded_conv`: MobileInvertedResidual
        :param input_shape: input shape (h, w, #channels)
        :param output_shape: output shape (h, w, #channels)
        :param expand: expansion ratio
        :param ks: kernel size
        :param stride: stride of the block
        :param id_skip: indicates whether the block has a residual connection
        """
        infos = [l_type, 'input:%s' % self.repr_shape(input_shape), 'output:%s' % self.repr_shape(output_shape), ]

        if l_type in ('expanded_conv',):
            assert None not in (expand, ks, stride, id_skip)
            infos += ['expand:%d' % expand, 'kernel:%d' % ks, 'stride:%d' % stride, 'idskip:%d' % id_skip]
        key = '-'.join(infos)
        return self.lut[key]['mean']

    def predict_network_latency(self, net, image_size=224):
        predicted_latency = 0
        # first conv
        predicted_latency += self.query(
            'Conv', [image_size, image_size, 3],
            [(image_size + 1) // 2, (image_size + 1) // 2, net.first_conv.out_channels]
        )
        # blocks
        fsize = (image_size + 1) // 2
        for block in net.blocks:
            mb_conv = block.conv
            shortcut = block.shortcut

            if mb_conv is None:
                continue
            if shortcut is None:
                idskip = 0
            else:
                idskip = 1
            out_fz = int((fsize - 1) / mb_conv.stride + 1)  # ceil(fsize / stride)
            block_latency = self.query(
                'expanded_conv', [fsize, fsize, mb_conv.in_channels], [out_fz, out_fz, mb_conv.out_channels],
                expand=mb_conv.expand_ratio, ks=mb_conv.kernel_size, stride=mb_conv.stride, id_skip=idskip
            )
            predicted_latency += block_latency
            fsize = out_fz
        # feature mix layer
        predicted_latency += self.query(
            'Conv_1', [fsize, fsize, net.feature_mix_layer.in_channels],
            [fsize, fsize, net.feature_mix_layer.out_channels]
        )
        # classifier
        predicted_latency += self.query(
            'Logits', [fsize, fsize, net.classifier.in_features], [net.classifier.out_features]  # 1000
        )
        return predicted_latency

    def predict_network_latency_given_config(self, net_config, image_size=224):
        predicted_latency = 0
        # first conv
        predicted_latency += self.query(
            'Conv', [image_size, image_size, 3],
            [(image_size + 1) // 2, (image_size + 1) // 2, net_config['first_conv']['out_channels']]
        )
        # blocks
        fsize = (image_size + 1) // 2
        for block in net_config['blocks']:
            mb_conv = block['mobile_inverted_conv'] if 'mobile_inverted_conv' in block else block['conv']
            shortcut = block['shortcut']

            if mb_conv is None:
                continue
            if shortcut is None:
                idskip = 0
            else:
                idskip = 1
            out_fz = int((fsize - 1) / mb_conv['stride'] + 1)
            block_latency = self.query(
                'expanded_conv', [fsize, fsize, mb_conv['in_channels']], [out_fz, out_fz, mb_conv['out_channels']],
                expand=mb_conv['expand_ratio'], ks=mb_conv['kernel_size'], stride=mb_conv['stride'], id_skip=idskip
            )
            predicted_latency += block_latency
            fsize = out_fz
        # feature mix layer
        predicted_latency += self.query(
            'Conv_1', [fsize, fsize, net_config['feature_mix_layer']['in_channels']],
            [fsize, fsize, net_config['feature_mix_layer']['out_channels']]
        )
        # classifier
        predicted_latency += self.query(
            'Logits', [fsize, fsize, net_config['classifier']['in_features']],
            [net_config['classifier']['out_features']]  # 1000
        )
        return predicted_latency

    @staticmethod
    def count_flops_given_config(net_config, image_size=224):
        flops = 0
        # first conv
        flops += count_conv_flop((image_size + 1) // 2, 3, net_config['first_conv']['out_channels'], 3, 1)
        # blocks
        fsize = (image_size + 1) // 2
        for block in net_config['blocks']:
            mb_conv = block['mobile_inverted_conv'] if 'mobile_inverted_conv' in block else block['conv']
            if mb_conv is None:
                continue
            out_fz = int((fsize - 1) / mb_conv['stride'] + 1)
            if mb_conv['mid_channels'] is None:
                mb_conv['mid_channels'] = round(mb_conv['in_channels'] * mb_conv['expand_ratio'])
            if mb_conv['expand_ratio'] != 1:
                # inverted bottleneck
                flops += count_conv_flop(fsize, mb_conv['in_channels'], mb_conv['mid_channels'], 1, 1)
            # depth conv
            flops += count_conv_flop(out_fz, mb_conv['mid_channels'], mb_conv['mid_channels'],
                                     mb_conv['kernel_size'], mb_conv['mid_channels'])
            # point linear
            flops += count_conv_flop(out_fz, mb_conv['mid_channels'], mb_conv['out_channels'], 1, 1)
            fsize = out_fz
        # feature mix layer
        flops += count_conv_flop(fsize, net_config['feature_mix_layer']['in_channels'],
                                 net_config['feature_mix_layer']['out_channels'], 1, 1)
        # classifier
        flops += count_conv_flop(1, net_config['classifier']['in_features'],
                                 net_config['classifier']['out_features'], 1, 1)
        return flops / 1e6  # MFLOPs


class MBv3LatencyTable(LatencyTable):

    def query(self, l_type: str, input_shape, output_shape, mid=None, ks=None, stride=None, id_skip=None,
              se=None, h_swish=None):
        infos = [l_type, 'input:%s' % self.repr_shape(input_shape), 'output:%s' % self.repr_shape(output_shape), ]

        if l_type in ('expanded_conv',):
            assert None not in (mid, ks, stride, id_skip, se, h_swish)
            infos += ['expand:%d' % mid, 'kernel:%d' % ks, 'stride:%d' % stride, 'idskip:%d' % id_skip,
                      'se:%d' % se, 'hs:%d' % h_swish]
        key = '-'.join(infos)
        return self.lut[key]['mean']

    def predict_network_latency(self, net, image_size=224):
        predicted_latency = 0
        # first conv
        predicted_latency += self.query(
            'Conv', [image_size, image_size, 3],
            [(image_size + 1) // 2, (image_size + 1) // 2, net.first_conv.out_channels]
        )
        # blocks
        fsize = (image_size + 1) // 2
        for block in net.blocks:
            mb_conv = block.conv
            shortcut = block.shortcut

            if mb_conv is None:
                continue
            if shortcut is None:
                idskip = 0
            else:
                idskip = 1
            out_fz = int((fsize - 1) / mb_conv.stride + 1)
            block_latency = self.query(
                'expanded_conv', [fsize, fsize, mb_conv.in_channels], [out_fz, out_fz, mb_conv.out_channels],
                mid=mb_conv.depth_conv.conv.in_channels, ks=mb_conv.kernel_size, stride=mb_conv.stride, id_skip=idskip,
                se=1 if mb_conv.use_se else 0, h_swish=1 if mb_conv.act_func == 'h_swish' else 0,
            )
            predicted_latency += block_latency
            fsize = out_fz
        # final expand layer
        predicted_latency += self.query(
            'Conv_1', [fsize, fsize, net.final_expand_layer.in_channels],
            [fsize, fsize, net.final_expand_layer.out_channels],
        )
        # global average pooling
        predicted_latency += self.query(
            'AvgPool2D', [fsize, fsize, net.final_expand_layer.out_channels],
            [1, 1, net.final_expand_layer.out_channels],
        )
        # feature mix layer
        predicted_latency += self.query(
            'Conv_2', [1, 1, net.feature_mix_layer.in_channels],
            [1, 1, net.feature_mix_layer.out_channels]
        )
        # classifier
        predicted_latency += self.query(
            'Logits', [1, 1, net.classifier.in_features], [net.classifier.out_features]
        )
        return predicted_latency

    def predict_network_latency_given_config(self, net_config, image_size=224):
        predicted_latency = 0
        # first conv
        predicted_latency += self.query(
            'Conv', [image_size, image_size, 3],
            [(image_size + 1) // 2, (image_size + 1) // 2, net_config['first_conv']['out_channels']]
        )
        # blocks
        fsize = (image_size + 1) // 2
        for block in net_config['blocks']:
            mb_conv = block['mobile_inverted_conv'] if 'mobile_inverted_conv' in block else block['conv']
            shortcut = block['shortcut']

            if mb_conv is None:
                continue
            if shortcut is None:
                idskip = 0
            else:
                idskip = 1
            out_fz = int((fsize - 1) / mb_conv['stride'] + 1)
            if mb_conv['mid_channels'] is None:
                mb_conv['mid_channels'] = round(mb_conv['in_channels'] * mb_conv['expand_ratio'])
            block_latency = self.query(
                'expanded_conv', [fsize, fsize, mb_conv['in_channels']], [out_fz, out_fz, mb_conv['out_channels']],
                mid=mb_conv['mid_channels'], ks=mb_conv['kernel_size'], stride=mb_conv['stride'], id_skip=idskip,
                se=1 if mb_conv['use_se'] else 0, h_swish=1 if mb_conv['act_func'] == 'h_swish' else 0,
            )
            predicted_latency += block_latency
            fsize = out_fz
        # final expand layer
        predicted_latency += self.query(
            'Conv_1', [fsize, fsize, net_config['final_expand_layer']['in_channels']],
            [fsize, fsize, net_config['final_expand_layer']['out_channels']],
        )
        # global average pooling
        predicted_latency += self.query(
            'AvgPool2D', [fsize, fsize, net_config['final_expand_layer']['out_channels']],
            [1, 1, net_config['final_expand_layer']['out_channels']],
        )
        # feature mix layer
        predicted_latency += self.query(
            'Conv_2', [1, 1, net_config['feature_mix_layer']['in_channels']],
            [1, 1, net_config['feature_mix_layer']['out_channels']]
        )
        # classifier
        predicted_latency += self.query(
            'Logits', [1, 1, net_config['classifier']['in_features']], [net_config['classifier']['out_features']]
        )
        return predicted_latency

    @staticmethod
    def count_flops_given_config(net_config, image_size=224):
        flops = 0
        # first conv
        flops += count_conv_flop((image_size + 1) // 2, 3, net_config['first_conv']['out_channels'], 3, 1)
        # blocks
        fsize = (image_size + 1) // 2
        for block in net_config['blocks']:
            mb_conv = block['mobile_inverted_conv'] if 'mobile_inverted_conv' in block else block['conv']
            if mb_conv is None:
                continue
            out_fz = int((fsize - 1) / mb_conv['stride'] + 1)
            if mb_conv['mid_channels'] is None:
                mb_conv['mid_channels'] = round(mb_conv['in_channels'] * mb_conv['expand_ratio'])
            if mb_conv['expand_ratio'] != 1:
                # inverted bottleneck
                flops += count_conv_flop(fsize, mb_conv['in_channels'], mb_conv['mid_channels'], 1, 1)
            # depth conv
            flops += count_conv_flop(out_fz, mb_conv['mid_channels'], mb_conv['mid_channels'],
                                     mb_conv['kernel_size'], mb_conv['mid_channels'])
            if mb_conv['use_se']:
                # SE layer
                se_mid = make_divisible(mb_conv['mid_channels'] // 4, divisor=MyNetwork.CHANNEL_DIVISIBLE)
                flops += count_conv_flop(1, mb_conv['mid_channels'], se_mid, 1, 1)
                flops += count_conv_flop(1, se_mid, mb_conv['mid_channels'], 1, 1)
            # point linear
            flops += count_conv_flop(out_fz, mb_conv['mid_channels'], mb_conv['out_channels'], 1, 1)
            fsize = out_fz
        # final expand layer
        flops += count_conv_flop(fsize, net_config['final_expand_layer']['in_channels'],
                                 net_config['final_expand_layer']['out_channels'], 1, 1)
        # feature mix layer
        flops += count_conv_flop(1, net_config['feature_mix_layer']['in_channels'],
                                 net_config['feature_mix_layer']['out_channels'], 1, 1)
        # classifier
        flops += count_conv_flop(1, net_config['classifier']['in_features'],
                                 net_config['classifier']['out_features'], 1, 1)
        return flops / 1e6  # MFLOPs


class ResNet50LatencyTable(LatencyTable):

    def query(self, **kwargs):
        raise NotImplementedError

    def predict_network_latency(self, net, image_size):
        raise NotImplementedError

    def predict_network_latency_given_config(self, net_config, image_size):
        raise NotImplementedError

    @staticmethod
    def count_flops_given_config(net_config, image_size=224):
        flops = 0
        # input stem
        for layer_config in net_config['input_stem']:
            if layer_config['name'] != 'ConvLayer':
                layer_config = layer_config['conv']
            in_channel = layer_config['in_channels']
            out_channel = layer_config['out_channels']
            out_image_size = int((image_size - 1) / layer_config['stride'] + 1)

            flops += count_conv_flop(out_image_size, in_channel, out_channel,
                                     layer_config['kernel_size'], layer_config.get('groups', 1))
            image_size = out_image_size
        # max pooling
        image_size = int((image_size - 1) / 2 + 1)
        # ResNetBottleneckBlocks
        for block_config in net_config['blocks']:
            in_channel = block_config['in_channels']
            out_channel = block_config['out_channels']

            out_image_size = int((image_size - 1) / block_config['stride'] + 1)
            mid_channel = block_config['mid_channels'] if block_config['mid_channels'] is not None \
                else round(out_channel * block_config['expand_ratio'])
            mid_channel = make_divisible(mid_channel, MyNetwork.CHANNEL_DIVISIBLE)

            # conv1
            flops += count_conv_flop(image_size, in_channel, mid_channel, 1, 1)
            # conv2
            flops += count_conv_flop(out_image_size, mid_channel, mid_channel,
                                     block_config['kernel_size'], block_config['groups'])
            # conv3
            flops += count_conv_flop(out_image_size, mid_channel, out_channel, 1, 1)
            # downsample
            if block_config['stride'] == 1 and in_channel == out_channel:
                pass
            else:
                flops += count_conv_flop(out_image_size, in_channel, out_channel, 1, 1)
            image_size = out_image_size
        # final classifier
        flops += count_conv_flop(1, net_config['classifier']['in_features'],
                                 net_config['classifier']['out_features'], 1, 1)
        return flops / 1e6  # MFLOPs
@@ -0,0 +1,5 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .evolution import *