from __future__ import print_function

import os

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch_geometric.utils import to_networkx

from all_path import PROCESSED_DATA_PATH, SCORE_MODEL_DATA_IDX_PATH
from analysis.arch_functions import (OPS, get_string_from_onehot_x,
                                     get_x_adj_from_opsdict_ofa)


def get_data_scaler(config):
    """Data normalizer. Assume data are always in [0, 1]."""
    if config.data.centered:
        # Rescale to [-1, 1]
        return lambda x: x * 2. - 1.
    else:
        return lambda x: x


def get_data_inverse_scaler(config):
    """Inverse data normalizer."""
    if config.data.centered:
        # Rescale [-1, 1] to [0, 1]
        return lambda x: (x + 1.) / 2.
    else:
        return lambda x: x
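
# Illustrative usage sketch (not part of the original pipeline): the two scalers are
# intended to be inverses of each other on [0, 1]-valued tensors. `config.data.centered`
# is assumed to be a boolean config field, as read above.
#
#   scaler = get_data_scaler(config)
#   inverse_scaler = get_data_inverse_scaler(config)
#   x = torch.rand(4, 20, 9)  # hypothetical node-feature tensor with values in [0, 1]
#   assert torch.allclose(inverse_scaler(scaler(x)), x, atol=1e-6)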


def networkx_graphs(dataset):
    """Convert every graph in the dataset to a directed networkx graph."""
    return [to_networkx(dataset[i], to_undirected=False, remove_self_loops=True)
            for i in range(len(dataset))]


def get_dataloader(config, train_dataset, eval_dataset, test_dataset):
    # The meta-predictor consumes (x, adj, label_dict, task) tuples and needs the custom
    # collate function below; other model types fall back to the default collation.
    collate_fn = collate_fn_ofa if config.model_type == 'meta_predictor' else None
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=config.training.batch_size,
                              shuffle=True,
                              collate_fn=collate_fn)
    eval_loader = DataLoader(dataset=eval_dataset,
                             batch_size=config.training.batch_size,
                             shuffle=False,
                             collate_fn=collate_fn)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=config.training.batch_size,
                             shuffle=False,
                             collate_fn=collate_fn)

    return train_loader, eval_loader, test_loader


def get_dataloader_iter(config, train_dataset, eval_dataset, test_dataset):
    # Cap the batch size at the dataset size so that small splits still form one full batch.
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=min(config.training.batch_size, len(train_dataset)),
                              shuffle=True)
    eval_loader = DataLoader(dataset=eval_dataset,
                             batch_size=min(config.training.batch_size, len(eval_dataset)),
                             shuffle=False)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=min(config.training.batch_size, len(test_dataset)),
                             shuffle=False)

    return train_loader, eval_loader, test_loader


def is_triu(mat):
    """Check whether `mat` is (numerically) upper-triangular."""
    return np.allclose(mat, np.triu(mat))


def collate_fn_ofa(batch):
    """Collate (x, adj, label_dict, task) samples into a batch for the meta-predictor."""
    x = torch.stack([item[0] for item in batch])
    adj = torch.stack([item[1] for item in batch])
    label_dict = {}
    for item in batch:
        for k, v in item[2].items():
            if k not in label_dict:
                label_dict[k] = []
            label_dict[k].append(v)
    for k, v in label_dict.items():
        label_dict[k] = torch.tensor(v)
    task = [item[3] for item in batch]
    return x, adj, label_dict, task
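
# Shape sketch (illustrative; per-item shapes are assumptions, not taken from the source):
#
#   x:          [num_nodes, num_ops]    one-hot operation matrix
#   adj:        [num_nodes, num_nodes]  adjacency matrix
#   label_dict: {'meta-acc': float}     scalar labels
#   task:       tensor of sampled per-class examples (see MetaTrainDatabaseOFA)
#
# so the collated batch becomes roughly ([B, num_nodes, num_ops], [B, num_nodes, num_nodes],
# {'meta-acc': FloatTensor[B]}, list of B task tensors).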


def get_dataset(config):
    """Create the train/eval/test architecture datasets.

    Args:
        config: An ml_collections.ConfigDict parsed from config files.

    Returns:
        train_dataset, eval_dataset, test_dataset
    """
    num_train = config.data.num_train if 'num_train' in config.data else None
    NASDataset = OFADataset

    train_dataset = NASDataset(
        config.data.root,
        config.data.split_ratio,
        config.data.except_inout,
        config.data.triu_adj,
        config.data.connect_prev,
        'train',
        config.data.label_list,
        config.data.tg_dataset,
        config.data.dataset_idx,
        num_train,
        node_rule_type=config.data.node_rule_type)
    eval_dataset = NASDataset(
        config.data.root,
        config.data.split_ratio,
        config.data.except_inout,
        config.data.triu_adj,
        config.data.connect_prev,
        'eval',
        config.data.label_list,
        config.data.tg_dataset,
        config.data.dataset_idx,
        num_train,
        node_rule_type=config.data.node_rule_type)
    test_dataset = NASDataset(
        config.data.root,
        config.data.split_ratio,
        config.data.except_inout,
        config.data.triu_adj,
        config.data.connect_prev,
        'test',
        config.data.label_list,
        config.data.tg_dataset,
        config.data.dataset_idx,
        num_train,
        node_rule_type=config.data.node_rule_type)

    return train_dataset, eval_dataset, test_dataset
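
# Illustrative config sketch (all values are placeholders, not taken from the original configs):
#
#   import ml_collections
#   config = ml_collections.ConfigDict()
#   config.data = ml_collections.ConfigDict()
#   config.data.root = '/path/to/ofa_data.pt'   # file loaded by OFADataset via torch.load
#   config.data.split_ratio = 0.8
#   config.data.except_inout = False
#   config.data.triu_adj = True
#   config.data.connect_prev = False            # OFADataset asserts this is False
#   config.data.label_list = None
#   config.data.tg_dataset = 'cifar10'
#   config.data.dataset_idx = 'random'
#   config.data.node_rule_type = None
#   train_ds, eval_ds, test_ds = get_dataset(config)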


def get_meta_dataset(config):
    database = MetaTrainDatabaseOFA
    data_path = PROCESSED_DATA_PATH

    train_dataset = database(
        data_path,
        config.model.num_sample,
        config.data.label_list,
        True,
        config.data.except_inout,
        config.data.triu_adj,
        config.data.connect_prev,
        'train')
    eval_dataset = database(
        data_path,
        config.model.num_sample,
        config.data.label_list,
        True,
        config.data.except_inout,
        config.data.triu_adj,
        config.data.connect_prev,
        'val')
    # MetaTestDataset is defined below but is not instantiated here.
    test_dataset = None
    return train_dataset, eval_dataset, test_dataset


def get_meta_dataloader(config, train_dataset, eval_dataset, test_dataset):
    if config.data.name == 'ofa':
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=config.training.batch_size,
                                  shuffle=True)
        eval_loader = DataLoader(dataset=eval_dataset,
                                 batch_size=config.training.batch_size)
    else:
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=config.training.batch_size,
                                  shuffle=True)
        eval_loader = DataLoader(dataset=eval_dataset,
                                 batch_size=config.training.batch_size,
                                 shuffle=False)
    test_loader = None
    return train_loader, eval_loader, test_loader
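
# Illustrative usage sketch (assumes a config that also provides `model.num_sample`,
# `data.name`, and `training.batch_size`, among the fields used above):
#
#   train_ds, eval_ds, test_ds = get_meta_dataset(config)
#   train_loader, eval_loader, _ = get_meta_dataloader(config, train_ds, eval_ds, test_ds)
#   x, adj, label_dict, task = next(iter(train_loader))  # see MetaTrainDatabaseOFA.__getitem__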


class MetaTestDataset(Dataset):
    def __init__(self, data_path, data_name, num_sample, num_class=None):
        self.num_sample = num_sample
        self.data_name = data_name

        num_class_dict = {
            'cifar100': 100,
            'cifar10': 10,
            'mnist': 10,
            'svhn': 10,
            'aircraft': 30,
            'pets': 37,
        }

        if num_class is not None:
            self.num_class = num_class
        else:
            self.num_class = num_class_dict[data_name]

        # Per-class data indexed by label; the OFA pipeline stores the aircraft split
        # under 'aircraft100bylabel.pt'.
        if 'ofa' in data_path and data_name == 'aircraft':
            file_name = 'aircraft100bylabel.pt'
        else:
            file_name = f'{data_name}bylabel.pt'
        self.x = torch.load(os.path.join(data_path, file_name))

    def __len__(self):
        # Effectively unbounded: every __getitem__ draws a fresh random sample per class.
        return 1000000

    def __getitem__(self, index):
        data = []
        for cls in range(self.num_class):
            cx = self.x[cls][0]
            ridx = torch.randperm(len(cx))
            data.append(cx[ridx[:self.num_sample]])
        return torch.cat(data)
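
# Illustrative shape sketch: __getitem__ concatenates the per-class draws along the first
# dimension, i.e. a tensor of shape [num_class * num_sample, ...] (assuming every class
# bucket in the .pt file holds at least `num_sample` entries).
#
#   ds = MetaTestDataset(data_path='...', data_name='cifar10', num_sample=20)
#   task = ds[0]   # the index is ignored; a new random draw is made on each call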


class MetaTrainDatabaseOFA(Dataset):
    def __init__(
            self,
            data_path,
            num_sample,
            label_list,
            is_pred=True,
            except_inout=False,
            triu_adj=True,
            connect_prev=False,
            mode='train'):

        self.ops_decoder = list(OPS.keys())
        self.mode = mode
        self.acc_norm = True
        self.num_sample = num_sample
        # Per-class ImageNet-32 features used to form the dataset-conditioning `task` sample.
        self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))

        if is_pred:
            self.dpath = f'{data_path}/predictor/processed/'
        else:
            raise NotImplementedError

        self.dname = 'database_219152_14.0K'
        data = torch.load(self.dpath + f'{self.dname}_{self.mode}.pt')
        self.net = data['net']
        self.x_list = []
        self.adj_list = []
        self.arch_str_list = []
        for net in self.net:
            x, adj = get_x_adj_from_opsdict_ofa(net)
            # ---------- operation matrix / adjacency ---------- #
            self.x_list.append(x)
            self.adj_list.append(torch.tensor(adj))
            # ---------- arch_str ---------- #
            self.arch_str_list.append(get_string_from_onehot_x(x))

        # ---------- labels ---------- #
        self.label_list = label_list
        if self.label_list is not None:
            self.flops_list = data['flops']
            self.params_list = None
            self.latency_list = None

        self.acc_list = data['acc']
        self.mean = data['mean']
        self.std = data['std']
        self.task_lst = data['class']

    def __len__(self):
        return len(self.acc_list)

    def __getitem__(self, index):
        x = self.x_list[index]
        adj = self.adj_list[index]
        acc = self.acc_list[index]
        classes = self.task_lst[index]

        # Build the task input: `num_sample` randomly chosen examples per class.
        data = []
        for cls in classes:
            cx = self.x[cls.item()][0]
            ridx = torch.randperm(len(cx))
            data.append(cx[ridx[:self.num_sample]])
        task = torch.cat(data)

        if self.acc_norm:
            acc = ((acc - self.mean) / self.std) / 100.0
        else:
            acc = acc / 100.0

        label_dict = {}
        if self.label_list is not None:
            assert type(self.label_list) == list
            for label in self.label_list:
                if label == 'meta-acc':
                    label_dict[label] = acc
                else:
                    raise ValueError(f'Unsupported label: {label}')
        return x, adj, label_dict, task
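
# Illustrative inversion of the accuracy normalization above (assuming acc_norm=True):
# if `pred` is a value in the normalized space, the raw accuracy (in %) is recovered by
#
#   raw_acc = pred * 100.0 * ds.std + ds.mean
#
# where `ds` is a MetaTrainDatabaseOFA instance; this mirrors
# acc -> ((acc - mean) / std) / 100.0 used in __getitem__.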


class OFADataset(Dataset):
    def __init__(
            self,
            data_path,
            split_ratio=0.8,
            except_inout=False,
            triu_adj=True,
            connect_prev=False,
            mode='train',
            label_list=None,
            tg_dataset=None,
            dataset_idx='random',
            num_train=None,
            node_rule_type=None):

        # ---------- entire dataset ---------- #
        self.data = torch.load(data_path)
        self.except_inout = except_inout
        self.triu_adj = triu_adj
        self.connect_prev = connect_prev
        self.node_rule_type = node_rule_type

        # ---------- x ---------- #
        self.x_list = self.data['x_none2zero']

        # ---------- adj ---------- #
        # Every architecture shares the same constant chain adjacency (node i -> node i+1).
        assert self.connect_prev == False
        self.n_adj = len(self.data['node_type'][0])
        const_adj = self.get_not_connect_prev_adj()
        self.adj_list = [const_adj] * len(self.x_list)

        # ---------- arch_str ---------- #
        self.arch_str_list = self.data['net_setting']

        # ---------- labels ---------- #
        self.label_list = label_list
        if self.label_list is not None:
            raise NotImplementedError

        # ---------- split dataset ---------- #
        self.ds_idx = list(torch.load(SCORE_MODEL_DATA_IDX_PATH))

        self.split_ratio = split_ratio
        if num_train is None:
            num_train = int(len(self.x_list) * self.split_ratio)
        num_test = len(self.x_list) - num_train

        # ---------- compute mean and std w/ training dataset ---------- #
        if self.label_list is not None:
            self.train_idx_list = self.ds_idx[:num_train]
            print('Computing mean and std of the training set...')
            from collections import defaultdict
            LABEL_TO_MEAN_STD = defaultdict(dict)
            assert type(self.label_list) == list
            for label in self.label_list:
                if label == 'test-acc':
                    self.test_acc_list_tr = [self.test_acc_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.test_acc_list_tr))
                elif label == 'flops':
                    self.flops_list_tr = [self.flops_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.flops_list_tr))
                elif label == 'params':
                    self.params_list_tr = [self.params_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.params_list_tr))
                elif label == 'latency':
                    self.latency_list_tr = [self.latency_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.latency_list_tr))
                else:
                    raise ValueError

        self.mode = mode
        if self.mode in ['train']:
            self.idx_list = self.ds_idx[:num_train]
        elif self.mode in ['eval']:
            self.idx_list = self.ds_idx[:num_test]
        elif self.mode in ['test']:
            self.idx_list = self.ds_idx[num_train:]

        self.x_list_ = [self.x_list[i] for i in self.idx_list]
        self.adj_list_ = [self.adj_list[i] for i in self.idx_list]
        self.arch_str_list_ = [self.arch_str_list[i] for i in self.idx_list]

        if self.label_list is not None:
            assert type(self.label_list) == list
            for label in self.label_list:
                if label == 'test-acc':
                    self.test_acc_list_ = [self.test_acc_list[i] for i in self.idx_list]
                    self.test_acc_list_ = self.normalize(self.test_acc_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                elif label == 'flops':
                    self.flops_list_ = [self.flops_list[i] for i in self.idx_list]
                    self.flops_list_ = self.normalize(self.flops_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                elif label == 'params':
                    self.params_list_ = [self.params_list[i] for i in self.idx_list]
                    self.params_list_ = self.normalize(self.params_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                elif label == 'latency':
                    self.latency_list_ = [self.latency_list[i] for i in self.idx_list]
                    self.latency_list_ = self.normalize(self.latency_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                else:
                    raise ValueError

    def normalize(self, original, mean, std):
        return [(i - mean) / std for i in original]

    def get_not_connect_prev_adj(self):
        # Chain adjacency: a single edge from each node to the next one.
        _adj = torch.zeros(self.n_adj, self.n_adj)
        for i in range(self.n_adj - 1):
            _adj[i, i + 1] = 1
        _adj = _adj.to(torch.float32).to('cpu')
        return _adj
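
    # Illustrative example: with n_adj = 4 the constant adjacency built here is
    #
    #   [[0., 1., 0., 0.],
    #    [0., 0., 1., 0.],
    #    [0., 0., 0., 1.],
    #    [0., 0., 0., 0.]]
    #
    # i.e. an upper-triangular chain, consistent with is_triu().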

    @property
    def adj(self):
        # All samples share the same constant adjacency, so return the first one.
        return self.adj_list_[0]

    def mask(self, algo='floyd', data='ofa'):
        from utils import aug_mask
        return aug_mask(self.adj, algo=algo, data=data)[0]

    def get_unnoramlized_entire_data(self, label, tg_dataset):
        entire_test_acc_list = self.data['test-acc'][tg_dataset]
        entire_flops_list = self.data['flops'][tg_dataset]
        entire_params_list = self.data['params'][tg_dataset]
        entire_latency_list = self.data['latency'][tg_dataset]

        if label == 'test-acc':
            return entire_test_acc_list
        elif label == 'flops':
            return entire_flops_list
        elif label == 'params':
            return entire_params_list
        elif label == 'latency':
            return entire_latency_list
        else:
            raise ValueError

    def get_unnoramlized_data(self, label, tg_dataset):
        # Same as above, but restricted to this split's indices.
        entire_test_acc_list = self.data['test-acc'][tg_dataset]
        entire_flops_list = self.data['flops'][tg_dataset]
        entire_params_list = self.data['params'][tg_dataset]
        entire_latency_list = self.data['latency'][tg_dataset]

        if label == 'test-acc':
            return [entire_test_acc_list[i] for i in self.idx_list]
        elif label == 'flops':
            return [entire_flops_list[i] for i in self.idx_list]
        elif label == 'params':
            return [entire_params_list[i] for i in self.idx_list]
        elif label == 'latency':
            return [entire_latency_list[i] for i in self.idx_list]
        else:
            raise ValueError

    def __len__(self):
        return len(self.x_list_)

    def __getitem__(self, index):
        label_dict = {}
        if self.label_list is not None:
            assert type(self.label_list) == list
            for label in self.label_list:
                if label == 'test-acc':
                    label_dict[label] = self.test_acc_list_[index]
                elif label == 'flops':
                    label_dict[label] = self.flops_list_[index]
                elif label == 'params':
                    label_dict[label] = self.params_list_[index]
                elif label == 'latency':
                    label_dict[label] = self.latency_list_[index]
                else:
                    raise ValueError

        return self.x_list_[index], self.adj_list_[index], label_dict
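
# Illustrative end-to-end sketch (hypothetical config fields; mirrors get_dataset /
# get_dataloader above rather than prescribing the original training entry point):
#
#   train_ds, eval_ds, test_ds = get_dataset(config)
#   train_loader, eval_loader, test_loader = get_dataloader(config, train_ds, eval_ds, test_ds)
#   scaler = get_data_scaler(config)
#   for x, adj, label_dict in train_loader:
#       x = scaler(x)
#       ...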