# diffusionNAG/NAS-Bench-201/datasets_nas.py
from __future__ import print_function
import torch
import os
import numpy as np
from collections import defaultdict
from torch.utils.data import DataLoader, Dataset
from analysis.arch_functions import decode_x_to_NAS_BENCH_201_matrix, decode_x_to_NAS_BENCH_201_string
from all_path import *


def get_data_scaler(config):
    """Data normalizer. Assume data are always in [0, 1]."""
    if config.data.centered:
        # Rescale to [-1, 1]
        return lambda x: x * 2. - 1.
    else:
        return lambda x: x


def get_data_inverse_scaler(config):
    """Inverse data normalizer."""
    if config.data.centered:
        # Rescale [-1, 1] to [0, 1]
        return lambda x: (x + 1.) / 2.
    else:
        return lambda x: x
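

# A minimal sketch (added for illustration, not part of the original module)
# showing how the two scalers above are meant to round-trip when
# `config.data.centered` is set. The SimpleNamespace config is a stand-in for
# the project's real config object; the tensor shape is arbitrary dummy data.
def _example_scaler_roundtrip():
    from types import SimpleNamespace
    config = SimpleNamespace(data=SimpleNamespace(centered=True))
    scaler = get_data_scaler(config)                    # [0, 1] -> [-1, 1]
    inverse_scaler = get_data_inverse_scaler(config)    # [-1, 1] -> [0, 1]
    x = torch.rand(4, 8, 7)                             # dummy data in [0, 1]
    assert torch.allclose(inverse_scaler(scaler(x)), x)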


def is_triu(mat):
    is_triu_ = np.allclose(mat, np.triu(mat))
    return is_triu_


def get_dataset(config):
    train_dataset = NASBench201Dataset(
        data_path=NASBENCH201_INFO,
        mode='train')
    eval_dataset = NASBench201Dataset(
        data_path=NASBENCH201_INFO,
        mode='eval')
    test_dataset = NASBench201Dataset(
        data_path=NASBENCH201_INFO,
        mode='test')
    return train_dataset, eval_dataset, test_dataset


def get_dataloader(config, train_dataset, eval_dataset, test_dataset):
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=config.training.batch_size,
                              shuffle=True,
                              collate_fn=None)
    eval_loader = DataLoader(dataset=eval_dataset,
                             batch_size=config.training.batch_size,
                             shuffle=False,
                             collate_fn=None)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=config.training.batch_size,
                             shuffle=False,
                             collate_fn=None)
    return train_loader, eval_loader, test_loader
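

# A hedged usage sketch (added for illustration): building the three
# NAS-Bench-201 splits and their loaders. It assumes a config object exposing
# `config.training.batch_size`, matching how `get_dataloader` is called above,
# and that the stored `x`/`adj` entries are tensors so the default collate
# function can stack a mini-batch.
def _example_build_nasbench201_loaders(config):
    train_ds, eval_ds, test_ds = get_dataset(config)
    train_loader, eval_loader, test_loader = get_dataloader(
        config, train_ds, eval_ds, test_ds)
    x, adj, label_dict = next(iter(train_loader))  # one mini-batch of encodings
    return train_loader, eval_loader, test_loader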


class NASBench201Dataset(Dataset):
    def __init__(
            self,
            data_path,
            split_ratio=1.0,
            mode='train',
            label_list=None,
            tg_dataset=None):

        self.ops_decoder = ['input', 'output', 'none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']

        # ---------- entire dataset ---------- #
        self.data = torch.load(data_path)
        # ---------- igraph ---------- #
        self.igraph_list = self.data['g']
        # ---------- x ---------- #
        self.x_list = self.data['x']
        # ---------- adj ---------- #
        adj = self.get_adj()
        self.adj_list = [adj] * len(self.igraph_list)
        # ---------- matrix ---------- #
        self.matrix_list = self.data['matrix']
        # ---------- arch_str ---------- #
        self.arch_str_list = self.data['str']
        # ---------- labels ---------- #
        self.label_list = label_list
        if self.label_list is not None:
            self.val_acc_list = self.data['val-acc'][tg_dataset]
            self.test_acc_list = self.data['test-acc'][tg_dataset]
            self.flops_list = self.data['flops'][tg_dataset]
            self.params_list = self.data['params'][tg_dataset]
            self.latency_list = self.data['latency'][tg_dataset]

        # ----------- split dataset ---------- #
        self.ds_idx = list(torch.load(DATA_PATH + '/ridx.pt'))
        self.split_ratio = split_ratio
        num_train = int(len(self.x_list) * self.split_ratio)
        num_test = len(self.x_list) - num_train

        # ----------- compute mean and std w/ training dataset ---------- #
        if self.label_list is not None:
            self.train_idx_list = self.ds_idx[:num_train]
            print('>>> Computing mean and std of the training set...')
            LABEL_TO_MEAN_STD = defaultdict(dict)
            assert type(self.label_list) == list, f"self.label_list is {type(self.label_list)}"
            for label in self.label_list:
                if label == 'val-acc':
                    self.val_acc_list_tr = [self.val_acc_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.val_acc_list_tr))
                elif label == 'test-acc':
                    self.test_acc_list_tr = [self.test_acc_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.test_acc_list_tr))
                elif label == 'flops':
                    self.flops_list_tr = [self.flops_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.flops_list_tr))
                elif label == 'params':
                    self.params_list_tr = [self.params_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.params_list_tr))
                elif label == 'latency':
                    self.latency_list_tr = [self.latency_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.latency_list_tr))
                else:
                    raise ValueError

        self.mode = mode
        if self.mode in ['train']:
            self.idx_list = self.ds_idx[:num_train]
        elif self.mode in ['eval']:
            if num_test == 0:
                self.idx_list = self.ds_idx[:100]
            else:
                self.idx_list = self.ds_idx[:num_test]
        elif self.mode in ['test']:
            if num_test == 0:
                self.idx_list = self.ds_idx[15000:]
            else:
                self.idx_list = self.ds_idx[num_train:]

        self.igraph_list_ = [self.igraph_list[i] for i in self.idx_list]
        self.x_list_ = [self.x_list[i] for i in self.idx_list]
        self.adj_list_ = [self.adj_list[i] for i in self.idx_list]
        self.matrix_list_ = [self.matrix_list[i] for i in self.idx_list]
        self.arch_str_list_ = [self.arch_str_list[i] for i in self.idx_list]

        if self.label_list is not None:
            assert type(self.label_list) == list
            for label in self.label_list:
                if label == 'val-acc':
                    self.val_acc_list_ = [self.val_acc_list[i] for i in self.idx_list]
                    self.val_acc_list_ = self.normalize(self.val_acc_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                elif label == 'test-acc':
                    self.test_acc_list_ = [self.test_acc_list[i] for i in self.idx_list]
                    self.test_acc_list_ = self.normalize(self.test_acc_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                elif label == 'flops':
                    self.flops_list_ = [self.flops_list[i] for i in self.idx_list]
                    self.flops_list_ = self.normalize(self.flops_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                elif label == 'params':
                    self.params_list_ = [self.params_list[i] for i in self.idx_list]
                    self.params_list_ = self.normalize(self.params_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                elif label == 'latency':
                    self.latency_list_ = [self.latency_list[i] for i in self.idx_list]
                    self.latency_list_ = self.normalize(self.latency_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                else:
                    raise ValueError

    def normalize(self, original, mean, std):
        return [(i - mean) / std for i in original]

    # def get_not_connect_prev_adj(self):
    def get_adj(self):
        adj = np.asarray(
            [[0, 1, 1, 1, 0, 0, 0, 0],
             [0, 0, 0, 0, 1, 1, 0, 0],
             [0, 0, 0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 0]]
        )
        adj = torch.tensor(adj, dtype=torch.float32, device=torch.device('cpu'))
        return adj

    @property
    def adj(self):
        return self.adj_list_[0]

    def mask(self, algo='floyd'):
        from utils import aug_mask
        return aug_mask(self.adj, algo=algo)[0]

    def get_unnoramlized_entire_data(self, label, tg_dataset):
        entire_val_acc_list = self.data['val-acc'][tg_dataset]
        entire_test_acc_list = self.data['test-acc'][tg_dataset]
        entire_flops_list = self.data['flops'][tg_dataset]
        entire_params_list = self.data['params'][tg_dataset]
        entire_latency_list = self.data['latency'][tg_dataset]
        if label == 'val-acc':
            return entire_val_acc_list
        elif label == 'test-acc':
            return entire_test_acc_list
        elif label == 'flops':
            return entire_flops_list
        elif label == 'params':
            return entire_params_list
        elif label == 'latency':
            return entire_latency_list
        else:
            raise ValueError

    def get_unnoramlized_data(self, label, tg_dataset):
        entire_val_acc_list = self.data['val-acc'][tg_dataset]
        entire_test_acc_list = self.data['test-acc'][tg_dataset]
        entire_flops_list = self.data['flops'][tg_dataset]
        entire_params_list = self.data['params'][tg_dataset]
        entire_latency_list = self.data['latency'][tg_dataset]
        if label == 'val-acc':
            return [entire_val_acc_list[i] for i in self.idx_list]
        elif label == 'test-acc':
            return [entire_test_acc_list[i] for i in self.idx_list]
        elif label == 'flops':
            return [entire_flops_list[i] for i in self.idx_list]
        elif label == 'params':
            return [entire_params_list[i] for i in self.idx_list]
        elif label == 'latency':
            return [entire_latency_list[i] for i in self.idx_list]
        else:
            raise ValueError

    def __len__(self):
        return len(self.x_list_)

    def __getitem__(self, index):
        label_dict = {}
        if self.label_list is not None:
            assert type(self.label_list) == list
            for label in self.label_list:
                if label == 'val-acc':
                    label_dict[f"{label}"] = self.val_acc_list_[index]
                elif label == 'test-acc':
                    label_dict[f"{label}"] = self.test_acc_list_[index]
                elif label == 'flops':
                    label_dict[f"{label}"] = self.flops_list_[index]
                elif label == 'params':
                    label_dict[f"{label}"] = self.params_list_[index]
                elif label == 'latency':
                    label_dict[f"{label}"] = self.latency_list_[index]
                else:
                    raise ValueError
        return self.x_list_[index], self.adj_list_[index], label_dict
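

# A minimal sketch (added for illustration, not in the original file) of using
# NASBench201Dataset directly with regression targets. `NASBENCH201_INFO`
# comes from `all_path`; 'cifar10' is one plausible value for `tg_dataset`,
# assuming the cached benchmark statistics are keyed by dataset name.
def _example_nasbench201_labels():
    dataset = NASBench201Dataset(
        data_path=NASBENCH201_INFO,
        mode='train',
        label_list=['val-acc', 'flops'],
        tg_dataset='cifar10')
    x, adj, label_dict = dataset[0]
    # label_dict holds z-score-normalized targets, e.g. label_dict['val-acc'].
    return x, adj, label_dict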


# ---------- Meta-Dataset ---------- #
def get_meta_dataset(config):
    train_dataset = MetaTrainDatabase(
        data_path=DATA_PATH,
        num_sample=config.model.num_sample,
        label_list=config.data.label_list,
        mode='train')
    eval_dataset = MetaTrainDatabase(
        data_path=DATA_PATH,
        num_sample=config.model.num_sample,
        label_list=config.data.label_list,
        mode='eval')
    test_dataset = None
    return train_dataset, eval_dataset, test_dataset


def get_meta_dataloader(config, train_dataset, eval_dataset, test_dataset):
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=config.training.batch_size,
                              shuffle=True)
    eval_loader = DataLoader(dataset=eval_dataset,
                             batch_size=config.training.batch_size,
                             shuffle=False)
    test_loader = None
    return train_loader, eval_loader, test_loader
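

# A hedged sketch of the meta-training entry point (added for illustration).
# The fields read here (`config.model.num_sample`, `config.data.label_list`,
# `config.training.batch_size`) are exactly the ones used by get_meta_dataset
# and get_meta_dataloader above; everything else about the config object is a
# stand-in for the project's real configuration.
def _example_meta_loaders(config):
    train_ds, eval_ds, test_ds = get_meta_dataset(config)
    train_loader, eval_loader, _ = get_meta_dataloader(config, train_ds, eval_ds, test_ds)
    # Each batch yields (x, adj, label_dict, task), where `task` stacks the
    # sampled support images of the corresponding meta-training task.
    return train_loader, eval_loader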


class MetaTrainDatabase(Dataset):
    def __init__(
            self,
            data_path,
            num_sample,
            label_list,
            mode='train'):

        self.ops_decoder = ['input', 'output', 'none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']

        self.mode = mode
        self.acc_norm = True
        self.num_sample = num_sample
        self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))

        mtr_data_path = os.path.join(data_path, 'meta_train_tasks_predictor.pt')
        idx_path = os.path.join(data_path, 'meta_train_tasks_predictor_idx.pt')

        data = torch.load(mtr_data_path)
        self.acc_list = data['acc']
        self.task = data['task']
        # ---------- igraph ---------- #
        self.igraph_list = data['g']
        # ---------- x ---------- #
        self.x_list = data['x']
        # ---------- adj ---------- #
        adj = self.get_adj()
        self.adj_list = [adj] * len(self.igraph_list)
        # ---------- matrix ----------- #
        if 'matrix' in data:
            self.matrix_list = data['matrix']
        else:
            self.matrix_list = [decode_x_to_NAS_BENCH_201_matrix(i) for i in self.x_list]
        # ---------- arch_str ---------- #
        if 'str' in data:
            self.arch_str_list = data['str']
        else:
            self.arch_str_list = [decode_x_to_NAS_BENCH_201_string(i, self.ops_decoder) for i in self.x_list]
        # ---------- label ---------- #
        self.label_list = label_list
        if self.label_list is not None:
            self.flops_list = torch.tensor(data['flops'])
            self.params_list = torch.tensor(data['params'])
            self.latency_list = torch.tensor(data['latency'])

        random_idx_lst = torch.load(idx_path)
        self.idx_lst = {}
        self.idx_lst['eval'] = random_idx_lst[:400]
        self.idx_lst['train'] = random_idx_lst[400:]
        self.acc_list = torch.tensor(self.acc_list)
        self.mean = torch.mean(self.acc_list[self.idx_lst['train']]).item()
        self.std = torch.std(self.acc_list[self.idx_lst['train']]).item()
        self.task_lst = torch.load(os.path.join(data_path, 'meta_train_task_lst.pt'))

    def get_adj(self):
        adj = np.asarray(
            [[0, 1, 1, 1, 0, 0, 0, 0],
             [0, 0, 0, 0, 1, 1, 0, 0],
             [0, 0, 0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 1, 0],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 1],
             [0, 0, 0, 0, 0, 0, 0, 0]]
        )
        adj = torch.tensor(adj, dtype=torch.float32, device=torch.device('cpu'))
        return adj

    @property
    def adj(self):
        return self.adj_list[0]

    def mask(self, algo='floyd'):
        from utils import aug_mask
        return aug_mask(self.adj, algo=algo)[0]

    def set_mode(self, mode):
        self.mode = mode

    def __len__(self):
        return len(self.idx_lst[self.mode])

    def __getitem__(self, index):
        data = []
        ridx = self.idx_lst[self.mode]
        tidx = self.task[ridx[index]]
        classes = self.task_lst[tidx]
        # ---------- igraph ----------
        graph = self.igraph_list[ridx[index]]
        # ---------- x ----------
        x = self.x_list[ridx[index]]
        # ---------- adj ----------
        adj = self.adj_list[ridx[index]]
        acc = self.acc_list[ridx[index]]
        for cls in classes:
            cx = self.x[cls - 1][0]
            # Use a separate name for the per-class permutation so it does not
            # shadow `ridx`, which is still needed for the label lookups below.
            perm = torch.randperm(len(cx))
            data.append(cx[perm[:self.num_sample]])
        task = torch.cat(data)

        if self.acc_norm:
            acc = ((acc - self.mean) / self.std) / 100.0
        else:
            acc = acc / 100.0

        label_dict = {}
        if self.label_list is not None:
            assert type(self.label_list) == list
            for label in self.label_list:
                if label == 'meta-acc':
                    label_dict[f"{label}"] = acc
                elif label == 'flops':
                    label_dict[f"{label}"] = self.flops_list[ridx[index]]
                elif label == 'params':
                    label_dict[f"{label}"] = self.params_list[ridx[index]]
                elif label == 'latency':
                    label_dict[f"{label}"] = self.latency_list[ridx[index]]
                else:
                    raise ValueError
        return x, adj, label_dict, task
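

# Illustrative sketch (not in the original source): switching the same
# MetaTrainDatabase between its 'train' and 'eval' index lists via set_mode.
# DATA_PATH and the 'meta-acc' label mirror how get_meta_dataset builds the
# database above; num_sample=20 is an arbitrary choice for illustration, and
# the preprocessed .pt files are assumed to exist under DATA_PATH.
def _example_meta_train_database():
    db = MetaTrainDatabase(
        data_path=DATA_PATH,
        num_sample=20,
        label_list=['meta-acc'],
        mode='train')
    x, adj, label_dict, task = db[0]   # normalized accuracy under 'meta-acc'
    db.set_mode('eval')                # now indexes the held-out 400 tasks
    return len(db)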


class MetaTestDataset(Dataset):
    def __init__(self, data_path, data_name, num_sample, num_class=None):
        self.num_sample = num_sample
        self.data_name = data_name

        num_class_dict = {
            'cifar100': 100,
            'cifar10': 10,
            'aircraft': 30,
            'pets': 37
        }

        if num_class is not None:
            self.num_class = num_class
        else:
            self.num_class = num_class_dict[data_name]

        self.x = torch.load(os.path.join(data_path, f'{data_name}bylabel.pt'))

    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        data = []
        classes = list(range(self.num_class))
        for cls in classes:
            cx = self.x[cls][0]
            ridx = torch.randperm(len(cx))
            data.append(cx[ridx[:self.num_sample]])
        x = torch.cat(data)
        return x
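

# Smoke-test sketch (added; not part of the original script). It assumes the
# preprocessed `{data_name}bylabel.pt` file lives under DATA_PATH, as with the
# meta-training data above; the 'cifar10' name and num_sample=20 are chosen
# only for illustration.
if __name__ == '__main__':
    test_set = MetaTestDataset(data_path=DATA_PATH, data_name='cifar10', num_sample=20)
    batch = test_set[0]   # num_class * num_sample images for the target dataset
    print('MetaTestDataset sample shape:', tuple(batch.shape))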