469 lines
14 KiB
Python
469 lines
14 KiB
Python
from __future__ import print_function
|
|
import torch
|
|
import os
|
|
import numpy as np
|
|
from collections import defaultdict
|
|
from torch.utils.data import DataLoader, Dataset
|
|
from analysis.arch_functions import decode_x_to_NAS_BENCH_201_matrix, decode_x_to_NAS_BENCH_201_string
|
|
from all_path import *
|
|
|
|
|
|
def get_data_scaler(config):
    """Return a data normalizer. Inputs are assumed to lie in [0, 1].

    If ``config.data.centered`` is truthy, the returned callable rescales
    [0, 1] onto [-1, 1]; otherwise it is the identity.
    """
    if not config.data.centered:
        return lambda x: x

    def _to_centered(x):
        # Affine map [0, 1] -> [-1, 1].
        return x * 2. - 1.

    return _to_centered
|
|
|
|
|
|
def get_data_inverse_scaler(config):
    """Return the inverse of :func:`get_data_scaler`.

    If ``config.data.centered`` is truthy, the returned callable maps
    [-1, 1] back onto [0, 1]; otherwise it is the identity.
    """
    if not config.data.centered:
        return lambda x: x

    def _from_centered(x):
        # Affine map [-1, 1] -> [0, 1].
        return (x + 1.) / 2.

    return _from_centered
|
|
|
|
|
|
def is_triu(mat):
    """Return True if ``mat`` is (numerically) upper-triangular.

    Uses ``np.allclose`` so tiny floating-point residue below the diagonal
    still counts as upper-triangular.
    """
    upper_part = np.triu(mat)
    return np.allclose(mat, upper_part)
|
|
|
|
|
|
def get_dataset(config):
    """Build the train/eval/test splits of :class:`NASBench201Dataset`.

    All three splits read the same info file (``NASBENCH201_INFO``); the
    ``mode`` argument selects which index slice each split uses.
    ``config`` is accepted for interface symmetry but not consulted here.
    """
    modes = ('train', 'eval', 'test')
    train_dataset, eval_dataset, test_dataset = (
        NASBench201Dataset(data_path=NASBENCH201_INFO, mode=m) for m in modes)
    return train_dataset, eval_dataset, test_dataset
|
|
|
|
|
|
def get_dataloader(config, train_dataset, eval_dataset, test_dataset):
    """Wrap the three dataset splits in :class:`DataLoader` objects.

    Only the training loader shuffles; eval and test iterate in order.
    Batch size comes from ``config.training.batch_size`` for all three.
    """
    batch_size = config.training.batch_size

    def _make(ds, shuffle):
        # collate_fn=None keeps torch's default batching behavior.
        return DataLoader(dataset=ds,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          collate_fn=None)

    train_loader = _make(train_dataset, True)
    eval_loader = _make(eval_dataset, False)
    test_loader = _make(test_dataset, False)
    return train_loader, eval_loader, test_loader
|
|
|
|
|
|
class NASBench201Dataset(Dataset):
    """NAS-Bench-201 architectures loaded from a preprocessed archive.

    The archive at ``data_path`` is a ``torch.save`` dict with (at least)
    keys ``'g'`` (igraphs), ``'x'`` (operation encodings), ``'matrix'``,
    ``'str'`` (architecture strings), and — per target dataset — label
    tables ``'val-acc'``, ``'test-acc'``, ``'flops'``, ``'params'``,
    ``'latency'``.

    Items are ``(x, adj, label_dict)`` where ``adj`` is the fixed 8-node
    NAS-Bench-201 cell DAG shared by every architecture, and
    ``label_dict`` holds the requested labels normalized by the training
    split's mean/std.
    """

    def __init__(
            self,
            data_path,
            split_ratio=1.0,
            mode='train',
            label_list=None,
            tg_dataset=None):
        """
        Args:
            data_path: path to the preprocessed NAS-Bench-201 archive.
            split_ratio: fraction of the (pre-shuffled) index list used
                for training; the remainder is the test split.
            mode: one of 'train' / 'eval' / 'test'.
            label_list: labels to attach to each item (subset of
                'val-acc', 'test-acc', 'flops', 'params', 'latency'),
                or None for unlabeled items.
            tg_dataset: key selecting the target dataset's label tables;
                only consulted when ``label_list`` is given.
        """
        # Fixed vocabulary of NAS-Bench-201 cell operations.
        self.ops_decoder = ['input', 'output', 'none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']

        # ---------- entire dataset ---------- #
        self.data = torch.load(data_path)
        # ---------- igraph ---------- #
        self.igraph_list = self.data['g']
        # ---------- x ---------- #
        self.x_list = self.data['x']
        # ---------- adj ---------- #
        # Every NAS-Bench-201 cell shares the same DAG, so one adjacency
        # matrix is replicated across the whole list.
        adj = self.get_adj()
        self.adj_list = [adj] * len(self.igraph_list)
        # ---------- matrix ---------- #
        self.matrix_list = self.data['matrix']
        # ---------- arch_str ---------- #
        self.arch_str_list = self.data['str']
        # ---------- labels ---------- #
        self.label_list = label_list
        if self.label_list is not None:
            self.val_acc_list = self.data['val-acc'][tg_dataset]
            self.test_acc_list = self.data['test-acc'][tg_dataset]
            self.flops_list = self.data['flops'][tg_dataset]
            self.params_list = self.data['params'][tg_dataset]
            self.latency_list = self.data['latency'][tg_dataset]

        # ----------- split dataset ---------- #
        # ridx.pt is a pre-shuffled permutation of dataset indices; all
        # splits are contiguous slices of it.
        self.ds_idx = list(torch.load(DATA_PATH + '/ridx.pt'))
        self.split_ratio = split_ratio
        num_train = int(len(self.x_list) * self.split_ratio)
        num_test = len(self.x_list) - num_train

        # ----------- compute mean and std w/ training dataset ---------- #
        # Label statistics are computed on the TRAINING slice only, then
        # applied to whichever split this instance serves.
        if self.label_list is not None:
            self.train_idx_list = self.ds_idx[:num_train]
            print('>>> Computing mean and std of the training set...')
            LABEL_TO_MEAN_STD = defaultdict(dict)
            assert type(self.label_list) == list, f"self.label_list is {type(self.label_list)}"
            for label in self.label_list:
                if label == 'val-acc':
                    self.val_acc_list_tr = [self.val_acc_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.val_acc_list_tr))
                elif label == 'test-acc':
                    self.test_acc_list_tr = [self.test_acc_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.test_acc_list_tr))
                elif label == 'flops':
                    self.flops_list_tr = [self.flops_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.flops_list_tr))
                elif label == 'params':
                    self.params_list_tr = [self.params_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.params_list_tr))
                elif label == 'latency':
                    self.latency_list_tr = [self.latency_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.latency_list_tr))
                else:
                    raise ValueError

        self.mode = mode
        if self.mode in ['train']:
            self.idx_list = self.ds_idx[:num_train]
        elif self.mode in ['eval']:
            if num_test == 0:
                # With split_ratio == 1.0 there is no held-out data;
                # eval reuses the first 100 training indices.
                self.idx_list = self.ds_idx[:100]
            else:
                # NOTE(review): this takes the FIRST num_test indices,
                # which overlap the training slice ds_idx[:num_train] —
                # looks intentional (eval-on-train proxy) but confirm.
                self.idx_list = self.ds_idx[:num_test]
        elif self.mode in ['test']:
            if num_test == 0:
                # NOTE(review): magic constant 15000 — presumably the
                # canonical NAS-Bench-201 train/test boundary; verify.
                self.idx_list = self.ds_idx[15000:]
            else:
                self.idx_list = self.ds_idx[num_train:]

        # Materialize this split's views of every per-architecture list.
        self.igraph_list_ = [self.igraph_list[i] for i in self.idx_list]
        self.x_list_ = [self.x_list[i] for i in self.idx_list]
        self.adj_list_ = [self.adj_list[i] for i in self.idx_list]
        self.matrix_list_ = [self.matrix_list[i] for i in self.idx_list]
        self.arch_str_list_ = [self.arch_str_list[i] for i in self.idx_list]

        if self.label_list is not None:
            assert type(self.label_list) == list
            for label in self.label_list:
                if label == 'val-acc':
                    self.val_acc_list_ = [self.val_acc_list[i] for i in self.idx_list]
                    self.val_acc_list_ = self.normalize(self.val_acc_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                elif label == 'test-acc':
                    self.test_acc_list_ = [self.test_acc_list[i] for i in self.idx_list]
                    self.test_acc_list_ = self.normalize(self.test_acc_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                elif label == 'flops':
                    self.flops_list_ = [self.flops_list[i] for i in self.idx_list]
                    self.flops_list_ = self.normalize(self.flops_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                elif label == 'params':
                    self.params_list_ = [self.params_list[i] for i in self.idx_list]
                    self.params_list_ = self.normalize(self.params_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                elif label == 'latency':
                    self.latency_list_ = [self.latency_list[i] for i in self.idx_list]
                    self.latency_list_ = self.normalize(self.latency_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                else:
                    raise ValueError

    def normalize(self, original, mean, std):
        """Standardize each value in ``original`` to (x - mean) / std."""
        return [(i-mean)/std for i in original]

    def get_adj(self):
        """Return the fixed 8-node NAS-Bench-201 cell DAG as a float CPU tensor.

        The matrix is upper-triangular: node 0 is the input, node 7 the
        output, and the six intermediate nodes are the cell's operations.
        """
        adj = np.asarray(
            [[0, 1, 1, 1, 0, 0, 0, 0],
             [0, 0, 0, 0, 1, 1, 0, 0],
             [0, 0, 0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 0]]
        )
        adj = torch.tensor(adj, dtype=torch.float32, device=torch.device('cpu'))
        return adj

    @property
    def adj(self):
        """The shared adjacency matrix (identical for every item)."""
        return self.adj_list_[0]

    def mask(self, algo='floyd'):
        """Return the augmented attention/reachability mask for the cell DAG.

        Imported lazily to avoid a circular import with ``utils``.
        """
        from utils import aug_mask
        return aug_mask(self.adj, algo=algo)[0]

    def get_unnoramlized_entire_data(self, label, tg_dataset):
        """Return the raw (un-normalized) label list for the WHOLE dataset.

        Note: name keeps the original 'unnoramlized' typo for caller
        compatibility.
        """
        entire_val_acc_list = self.data['val-acc'][tg_dataset]
        entire_test_acc_list = self.data['test-acc'][tg_dataset]
        entire_flops_list = self.data['flops'][tg_dataset]
        entire_params_list = self.data['params'][tg_dataset]
        entire_latency_list = self.data['latency'][tg_dataset]

        if label == 'val-acc':
            return entire_val_acc_list
        elif label == 'test-acc':
            return entire_test_acc_list
        elif label == 'flops':
            return entire_flops_list
        elif label == 'params':
            return entire_params_list
        elif label == 'latency':
            return entire_latency_list
        else:
            raise ValueError

    def get_unnoramlized_data(self, label, tg_dataset):
        """Return the raw (un-normalized) label values for THIS split only.

        Note: name keeps the original 'unnoramlized' typo for caller
        compatibility.
        """
        entire_val_acc_list = self.data['val-acc'][tg_dataset]
        entire_test_acc_list = self.data['test-acc'][tg_dataset]
        entire_flops_list = self.data['flops'][tg_dataset]
        entire_params_list = self.data['params'][tg_dataset]
        entire_latency_list = self.data['latency'][tg_dataset]

        if label == 'val-acc':
            return [entire_val_acc_list[i] for i in self.idx_list]
        elif label == 'test-acc':
            return [entire_test_acc_list[i] for i in self.idx_list]
        elif label == 'flops':
            return [entire_flops_list[i] for i in self.idx_list]
        elif label == 'params':
            return [entire_params_list[i] for i in self.idx_list]
        elif label == 'latency':
            return [entire_latency_list[i] for i in self.idx_list]
        else:
            raise ValueError

    def __len__(self):
        """Number of architectures in this split."""
        return len(self.x_list_)

    def __getitem__(self, index):
        """Return ``(x, adj, label_dict)`` for the index-th architecture.

        ``label_dict`` maps each requested label name to its normalized
        value; it is empty when the dataset was built without labels.
        """
        label_dict = {}
        if self.label_list is not None:
            assert type(self.label_list) == list
            for label in self.label_list:
                if label == 'val-acc':
                    label_dict[f"{label}"] = self.val_acc_list_[index]
                elif label == 'test-acc':
                    label_dict[f"{label}"] = self.test_acc_list_[index]
                elif label == 'flops':
                    label_dict[f"{label}"] = self.flops_list_[index]
                elif label == 'params':
                    label_dict[f"{label}"] = self.params_list_[index]
                elif label == 'latency':
                    label_dict[f"{label}"] = self.latency_list_[index]
                else:
                    raise ValueError
        return self.x_list_[index], self.adj_list_[index], label_dict
|
|
|
|
|
|
# ---------- Meta-Dataset ---------- #
|
|
def get_meta_dataset(config):
    """Build the meta-train and meta-eval splits of :class:`MetaTrainDatabase`.

    There is no meta-test split here, so the third return value is None
    (keeping the same triple shape as :func:`get_dataset`).
    """
    common_kwargs = dict(
        data_path=DATA_PATH,
        num_sample=config.model.num_sample,
        label_list=config.data.label_list)

    train_dataset = MetaTrainDatabase(mode='train', **common_kwargs)
    eval_dataset = MetaTrainDatabase(mode='eval', **common_kwargs)

    return train_dataset, eval_dataset, None
|
|
|
|
|
|
def get_meta_dataloader(config, train_dataset, eval_dataset, test_dataset):
    """Wrap the meta-train/eval datasets in DataLoaders.

    The test loader is always None (there is no meta-test split);
    ``test_dataset`` is accepted only for interface symmetry.
    """
    batch_size = config.training.batch_size

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True)
    eval_loader = DataLoader(dataset=eval_dataset,
                             batch_size=batch_size,
                             shuffle=False)

    return train_loader, eval_loader, None
|
|
|
|
|
|
class MetaTrainDatabase(Dataset):
    """Meta-training database pairing NAS-Bench-201 architectures with tasks.

    Loads ``meta_train_tasks_predictor.pt`` (architectures, per-task
    accuracies, flops/params/latency) plus per-class ImageNet-32 feature
    tensors. Each item is ``(x, adj, label_dict, task)`` where ``task`` is
    ``num_sample`` feature rows drawn from each class of the sample's task.

    Bug fixed vs. the original: in ``__getitem__`` the per-class
    ``torch.randperm`` result was assigned to the same name (``ridx``) as
    the split's dataset-index permutation, so the flops/params/latency
    labels were afterwards indexed with a random class-sample permutation
    — yielding wrong labels or an IndexError. The two indices are now
    kept separate.
    """

    def __init__(
            self,
            data_path,
            num_sample,
            label_list,
            mode='train'):
        """
        Args:
            data_path: directory holding the meta-training archives.
            num_sample: images sampled per class when assembling a task.
            label_list: labels to attach per item (subset of 'meta-acc',
                'flops', 'params', 'latency'), or None.
            mode: 'train' or 'eval'; switchable later via ``set_mode``.
        """
        # Fixed vocabulary of NAS-Bench-201 cell operations.
        self.ops_decoder = ['input', 'output', 'none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']

        self.mode = mode
        self.acc_norm = True  # standardize accuracies with train-split stats
        self.num_sample = num_sample
        # Per-class image/feature tensors, indexed by class id - 1.
        self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))

        mtr_data_path = os.path.join(data_path, 'meta_train_tasks_predictor.pt')
        idx_path = os.path.join(data_path, 'meta_train_tasks_predictor_idx.pt')
        data = torch.load(mtr_data_path)

        self.acc_list = data['acc']
        self.task = data['task']

        # ---------- igraph ---------- #
        self.igraph_list = data['g']
        # ---------- x ---------- #
        self.x_list = data['x']
        # ---------- adj ---------- #
        # All NAS-Bench-201 cells share one DAG; replicate it per item.
        adj = self.get_adj()
        self.adj_list = [adj] * len(self.igraph_list)
        # ---------- matrix ----------- #
        # Decode lazily only if the archive predates the 'matrix' field.
        if 'matrix' in data:
            self.matrix_list = data['matrix']
        else:
            self.matrix_list = [decode_x_to_NAS_BENCH_201_matrix(i) for i in self.x_list]
        # ---------- arch_str ---------- #
        if 'str' in data:
            self.arch_str_list = data['str']
        else:
            self.arch_str_list = [decode_x_to_NAS_BENCH_201_string(i, self.ops_decoder) for i in self.x_list]
        # ---------- label ---------- #
        self.label_list = label_list
        if self.label_list is not None:
            self.flops_list = torch.tensor(data['flops'])
            self.params_list = torch.tensor(data['params'])
            self.latency_list = torch.tensor(data['latency'])

        # First 400 shuffled indices are held out for eval; rest train.
        random_idx_lst = torch.load(idx_path)
        self.idx_lst = {}
        self.idx_lst['eval'] = random_idx_lst[:400]
        self.idx_lst['train'] = random_idx_lst[400:]
        self.acc_list = torch.tensor(self.acc_list)
        # Accuracy statistics computed on the training split only.
        self.mean = torch.mean(self.acc_list[self.idx_lst['train']]).item()
        self.std = torch.std(self.acc_list[self.idx_lst['train']]).item()
        self.task_lst = torch.load(os.path.join(data_path, 'meta_train_task_lst.pt'))

    def get_adj(self):
        """Return the fixed 8-node NAS-Bench-201 cell DAG as a float CPU tensor."""
        adj = np.asarray(
            [[0, 1, 1, 1, 0, 0, 0, 0],
             [0, 0, 0, 0, 1, 1, 0, 0],
             [0, 0, 0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 0]]
        )
        adj = torch.tensor(adj, dtype=torch.float32, device=torch.device('cpu'))
        return adj

    @property
    def adj(self):
        """The shared adjacency matrix (identical for every item)."""
        return self.adj_list[0]

    def mask(self, algo='floyd'):
        """Return the augmented mask for the cell DAG (lazy import avoids a cycle)."""
        from utils import aug_mask
        return aug_mask(self.adj, algo=algo)[0]

    def set_mode(self, mode):
        """Switch between the 'train' and 'eval' index splits."""
        self.mode = mode

    def __len__(self):
        """Number of samples in the current mode's split."""
        return len(self.idx_lst[self.mode])

    def __getitem__(self, index):
        """Return ``(x, adj, label_dict, task)`` for the index-th sample.

        ``task`` is the concatenation of ``num_sample`` randomly chosen
        rows from each class of this sample's task.
        """
        # Resolve the dataset index ONCE and keep it; the per-class
        # permutations below must not clobber it (this was the bug).
        sample_idx = self.idx_lst[self.mode][index]
        tidx = self.task[sample_idx]
        classes = self.task_lst[tidx]

        # ---------- x / adj / acc ----------
        x = self.x_list[sample_idx]
        adj = self.adj_list[sample_idx]
        acc = self.acc_list[sample_idx]

        # ---------- assemble the task tensor ----------
        data = []
        for cls in classes:
            # NOTE(review): class ids appear to be 1-based here (cls-1),
            # unlike MetaTestDataset which indexes with cls — confirm.
            cx = self.x[cls - 1][0]
            perm = torch.randperm(len(cx))
            data.append(cx[perm[:self.num_sample]])
        task = torch.cat(data)

        # Accuracies are reported in percent; scale to [0, 1]-ish range,
        # optionally standardized by the train split's mean/std.
        if self.acc_norm:
            acc = ((acc - self.mean) / self.std) / 100.0
        else:
            acc = acc / 100.0

        label_dict = {}
        if self.label_list is not None:
            assert type(self.label_list) == list
            for label in self.label_list:
                if label == 'meta-acc':
                    label_dict[f"{label}"] = acc
                elif label == 'flops':
                    label_dict[f"{label}"] = self.flops_list[sample_idx]
                elif label == 'params':
                    label_dict[f"{label}"] = self.params_list[sample_idx]
                elif label == 'latency':
                    label_dict[f"{label}"] = self.latency_list[sample_idx]
                else:
                    raise ValueError

        return x, adj, label_dict, task
|
|
|
|
|
|
class MetaTestDataset(Dataset):
    """Endless sampler of per-class feature subsets for a meta-test dataset.

    Each item is a tensor concatenating ``num_sample`` randomly chosen
    rows from every class of the chosen dataset. The length is a large
    sentinel (1,000,000) so the loader can draw indefinitely; the index
    itself is ignored.
    """

    # Default class counts for the supported meta-test datasets.
    _NUM_CLASS_BY_NAME = {
        'cifar100': 100,
        'cifar10': 10,
        'aircraft': 30,
        'pets': 37
    }

    def __init__(self, data_path, data_name, num_sample, num_class=None):
        """
        Args:
            data_path: directory holding ``<data_name>bylabel.pt``.
            data_name: dataset key ('cifar100', 'cifar10', 'aircraft', 'pets').
            num_sample: rows sampled per class for each item.
            num_class: explicit class count; defaults by ``data_name``.
        """
        self.num_sample = num_sample
        self.data_name = data_name

        if num_class is None:
            self.num_class = self._NUM_CLASS_BY_NAME[data_name]
        else:
            self.num_class = num_class

        # Per-class tensors; indexed by class id (0-based here).
        self.x = torch.load(os.path.join(data_path, f'{data_name}bylabel.pt'))

    def __len__(self):
        # Sentinel length: items are random draws, not a fixed list.
        return 1000000

    def __getitem__(self, index):
        # The index is intentionally unused — every access is a fresh draw.
        picked = []
        for cls in range(self.num_class):
            pool = self.x[cls][0]
            order = torch.randperm(len(pool))
            picked.append(pool[order[:self.num_sample]])
        return torch.cat(picked)