diffusionNAG/MobileNetV3/datasets_nas.py
2024-03-15 14:38:51 +00:00

493 lines
18 KiB
Python

from __future__ import print_function
import torch
import os
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch_geometric.utils import to_networkx
from analysis.arch_functions import get_x_adj_from_opsdict_ofa, get_string_from_onehot_x
from all_path import PROCESSED_DATA_PATH, SCORE_MODEL_DATA_IDX_PATH
from analysis.arch_functions import OPS
def get_data_scaler(config):
    """Build the forward normalizer for data assumed to lie in [0, 1].

    When ``config.data.centered`` is set, values are remapped to [-1, 1];
    otherwise the identity transform is returned.
    """
    if not config.data.centered:
        return lambda x: x
    # Affine map [0, 1] -> [-1, 1].
    return lambda x: x * 2. - 1.
def get_data_inverse_scaler(config):
    """Build the inverse of :func:`get_data_scaler`.

    Maps centered data back from [-1, 1] to [0, 1]; identity otherwise.
    """
    if not config.data.centered:
        return lambda x: x
    # Affine map [-1, 1] -> [0, 1].
    return lambda x: (x + 1.) / 2.
def networkx_graphs(dataset):
    """Convert every item of a torch_geometric dataset into a networkx digraph."""
    graphs = []
    for idx in range(len(dataset)):
        # Keep edge directions; drop self loops.
        graphs.append(to_networkx(dataset[idx], to_undirected=False, remove_self_loops=True))
    return graphs
def get_dataloader(config, train_dataset, eval_dataset, test_dataset):
    """Wrap the three dataset splits in DataLoaders.

    Only the training loader shuffles. The OFA meta-predictor needs its
    custom collate function; every other model type keeps the default.
    """
    collate = collate_fn_ofa if config.model_type == 'meta_predictor' else None
    batch = config.training.batch_size

    def make(dataset, shuffle):
        # Single place that keeps loader construction consistent across splits.
        return DataLoader(dataset=dataset,
                          batch_size=batch,
                          shuffle=shuffle,
                          collate_fn=collate)

    return make(train_dataset, True), make(eval_dataset, False), make(test_dataset, False)
def get_dataloader_iter(config, train_dataset, eval_dataset, test_dataset):
    """Build loaders whose batch size never exceeds the split length."""
    requested = config.training.batch_size

    def fitted(dataset):
        # Shrink the batch to the dataset size when the split is small.
        return requested if len(dataset) > requested else len(dataset)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=fitted(train_dataset),
                              shuffle=True)
    eval_loader = DataLoader(dataset=eval_dataset,
                             batch_size=fitted(eval_dataset),
                             shuffle=False)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=fitted(test_dataset),
                             shuffle=False)
    return train_loader, eval_loader, test_loader
def is_triu(mat):
    """Return True when ``mat`` equals its own upper-triangular part."""
    upper = np.triu(mat)
    return np.allclose(mat, upper)
def collate_fn_ofa(batch):
    """Collate OFA meta-predictor samples into a single batch.

    Each sample is a 4-tuple ``(x, adj, label_dict, task)``. Node features
    and adjacency matrices are stacked along a new batch dimension, the
    per-sample label dicts are merged into one dict of 1-D tensors, and
    the task tensors are returned as a plain list.

    Args:
        batch: list of ``(x, adj, label_dict, task)`` tuples.

    Returns:
        Tuple ``(x, adj, label_dict, task)`` with batched contents.
    """
    x = torch.stack([item[0] for item in batch])
    adj = torch.stack([item[1] for item in batch])
    # Group per-sample label values by key with a single lookup per key
    # (replaces the original `if not k in label_dict.keys()` double-lookup).
    grouped = {}
    for item in batch:
        for key, value in item[2].items():
            grouped.setdefault(key, []).append(value)
    label_dict = {key: torch.tensor(values) for key, values in grouped.items()}
    task = [item[3] for item in batch]
    return x, adj, label_dict, task
def get_dataset(config):
    """Create the OFA dataset splits for training and evaluation.

    Args:
        config: A ml_collection.ConfigDict parsed from config files.

    Returns:
        Tuple ``(train_ds, eval_ds, test_ds)``.
    """
    num_train = config.data.num_train if 'num_train' in config.data else None
    NASDataset = OFADataset

    def build(mode):
        # All three splits share every constructor argument except `mode`.
        return NASDataset(
            config.data.root,
            config.data.split_ratio,
            config.data.except_inout,
            config.data.triu_adj,
            config.data.connect_prev,
            mode,
            config.data.label_list,
            config.data.tg_dataset,
            config.data.dataset_idx,
            num_train,
            node_rule_type=config.data.node_rule_type)

    return build('train'), build('eval'), build('test')
def get_meta_dataset(config):
    """Create the OFA meta-train and meta-val databases.

    The meta-test dataset is constructed elsewhere (see MetaTestDataset),
    so ``None`` is returned in its slot.
    """
    database = MetaTrainDatabaseOFA
    data_path = PROCESSED_DATA_PATH

    def build(mode):
        # Both splits share every constructor argument except `mode`.
        return database(
            data_path,
            config.model.num_sample,
            config.data.label_list,
            True,
            config.data.except_inout,
            config.data.triu_adj,
            config.data.connect_prev,
            mode)

    return build('train'), build('val'), None
def get_meta_dataloader(config, train_dataset, eval_dataset, test_dataset):
    """Wrap meta-learning dataset splits in DataLoaders.

    Only the training loader shuffles; no test loader is built here,
    so ``None`` is returned in its place.
    """
    batch = config.training.batch_size
    if config.data.name == 'ofa':
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=batch,
                                  shuffle=True)
        # shuffle defaults to False for evaluation.
        eval_loader = DataLoader(dataset=eval_dataset,
                                 batch_size=batch)
    else:
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=batch,
                                  shuffle=True)
        eval_loader = DataLoader(dataset=eval_dataset,
                                 batch_size=batch,
                                 shuffle=False)
    return train_loader, eval_loader, None
class MetaTestDataset(Dataset):
    """Samples class-balanced image subsets for meta-test evaluation.

    Each item is ``num_sample`` images drawn at random from every class of
    the target dataset, concatenated along the first dimension.
    """

    def __init__(self, data_path, data_name, num_sample, num_class=None):
        self.num_sample = num_sample
        self.data_name = data_name
        num_class_dict = {
            'cifar100': 100,
            'cifar10': 10,
            'mnist': 10,
            'svhn': 10,
            'aircraft': 30,
            'pets': 37
        }
        self.num_class = num_class if num_class is not None else num_class_dict[data_name]
        # OFA pipelines keep aircraft under a dedicated 100-per-class file.
        if 'ofa' in data_path and data_name == 'aircraft':
            file_name = 'aircraft100bylabel.pt'
        else:
            file_name = f'{data_name}bylabel.pt'
        self.x = torch.load(os.path.join(data_path, file_name))

    def __len__(self):
        # Effectively unbounded: each access samples randomly, so any index works.
        return 1000000

    def __getitem__(self, index):
        per_class = []
        for cls in range(self.num_class):
            pool = self.x[cls][0]
            order = torch.randperm(len(pool))
            per_class.append(pool[order[:self.num_sample]])
        return torch.cat(per_class)
class MetaTrainDatabaseOFA(Dataset):
    """Meta-training database over the OFA search space.

    Loads a preprocessed architecture database plus per-class ImageNet-32
    image pools, and yields ``(x, adj, label_dict, task)`` tuples where
    ``task`` is a bag of images sampled from the classes attached to each
    architecture record.
    """

    def __init__(
        self,
        data_path,
        num_sample,
        label_list,
        is_pred=True,
        except_inout=False,
        triu_adj=True,
        connect_prev=False,
        mode='train'):
        self.ops_decoder = list(OPS.keys())
        self.mode = mode
        self.acc_norm = True
        self.num_sample = num_sample
        # Per-class image pools used to assemble the task representation.
        self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))
        if not is_pred:
            # Only the predictor-processed layout is supported.
            raise NotImplementedError
        self.dpath = f'{data_path}/predictor/processed/'
        self.dname = 'database_219152_14.0K'
        data = torch.load(self.dpath + f'{self.dname}_{self.mode}.pt')
        self.net = data['net']
        self.x_list = []
        self.adj_list = []
        self.arch_str_list = []
        for net in self.net:
            node_feat, adj = get_x_adj_from_opsdict_ofa(net)
            # ---------- matrix ---------- #
            self.x_list.append(node_feat)
            self.adj_list.append(torch.tensor(adj))
            # ---------- arch_str ---------- #
            self.arch_str_list.append(get_string_from_onehot_x(node_feat))
        # ---------- labels ---------- #
        self.label_list = label_list
        if self.label_list is not None:
            # Only FLOPs and accuracy are stored in this database.
            self.flops_list = data['flops']
            self.params_list = None
            self.latency_list = None
        self.acc_list = data['acc']
        self.mean = data['mean']
        self.std = data['std']
        self.task_lst = data['class']

    def __len__(self):
        return len(self.acc_list)

    def __getitem__(self, index):
        node_feat = self.x_list[index]
        adj = self.adj_list[index]
        acc = self.acc_list[index]
        # Sample `num_sample` images from each class tied to this record.
        pools = []
        for cls in self.task_lst[index]:
            images = self.x[cls.item()][0]
            order = torch.randperm(len(images))
            pools.append(images[order[:self.num_sample]])
        task = torch.cat(pools)
        # Accuracy is standardized with precomputed stats, then scaled to [~0, 1].
        if self.acc_norm:
            acc = ((acc - self.mean) / self.std) / 100.0
        else:
            acc = acc / 100.0
        label_dict = {}
        if self.label_list is not None:
            assert type(self.label_list) == list
            for label in self.label_list:
                if label != 'meta-acc':
                    raise ValueError
                label_dict[f"{label}"] = acc
        return node_feat, adj, label_dict, task
class OFADataset(Dataset):
    """OFA (Once-for-All) architecture dataset for score-model training.

    Loads a preprocessed architecture dump from ``data_path``, pairs every
    architecture with one shared chain-shaped adjacency matrix, and splits
    samples into train/eval/test along a precomputed index permutation
    loaded from SCORE_MODEL_DATA_IDX_PATH.
    """

    def __init__(
        self,
        data_path,
        split_ratio=0.8,
        except_inout=False,
        triu_adj=True,
        connect_prev=False,
        mode='train',
        label_list=None,
        tg_dataset=None,
        dataset_idx='random',
        num_train=None,
        node_rule_type=None):
        # ---------- entire dataset ---------- #
        self.data = torch.load(data_path)
        self.except_inout = except_inout
        self.triu_adj = triu_adj
        self.connect_prev = connect_prev
        self.node_rule_type = node_rule_type
        # ---------- x ---------- #
        # Node features keyed 'x_none2zero' ('none' ops presumably encoded
        # as zeros, judging by the key name — TODO confirm).
        self.x_list = self.data['x_none2zero']
        # ---------- adj ---------- #
        # Only the no-previous-connection (simple chain) topology is supported.
        assert self.connect_prev == False
        self.n_adj = len(self.data['node_type'][0])
        const_adj = self.get_not_connect_prev_adj()
        # Every sample shares the same constant adjacency object.
        self.adj_list = [const_adj] * len(self.x_list)
        # ---------- arch_str ---------- #
        self.arch_str_list = self.data['net_setting']
        # ---------- labels ---------- #
        # NOTE(review): label-conditioned loading is not implemented for this
        # dataset, so every `if self.label_list is not None` branch below is
        # currently dead code.
        self.label_list = label_list
        if self.label_list is not None:
            raise NotImplementedError
        # ----------- split dataset ---------- #
        # Precomputed permutation of sample indices shared across runs.
        self.ds_idx = list(torch.load(SCORE_MODEL_DATA_IDX_PATH))
        self.split_ratio = split_ratio
        if num_train is None:
            num_train = int(len(self.x_list) * self.split_ratio)
            num_test = len(self.x_list) - num_train
        else:
            num_train = num_train
            num_test = len(self.x_list) - num_train
        # ----------- compute mean and std w/ training dataset ---------- #
        if self.label_list is not None:
            self.train_idx_list = self.ds_idx[:num_train]
            print('Computing mean and std of the training set...')
            from collections import defaultdict
            LABEL_TO_MEAN_STD = defaultdict(dict)
            assert type(self.label_list) == list
            for label in self.label_list:
                if label == 'test-acc':
                    # NOTE(review): self.test_acc_list (and the flops/params/
                    # latency lists below) are never assigned in this class;
                    # this branch would raise AttributeError if ever reached.
                    self.test_acc_list_tr = [self.test_acc_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.test_acc_list_tr))
                elif label == 'flops':
                    self.flops_list_tr = [self.flops_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.flops_list_tr))
                elif label == 'params':
                    self.params_list_tr = [self.params_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.params_list_tr))
                elif label == 'latency':
                    self.latency_list_tr = [self.latency_list[i] for i in self.train_idx_list]
                    LABEL_TO_MEAN_STD[label]['std'], LABEL_TO_MEAN_STD[label]['mean'] = torch.std_mean(torch.tensor(self.latency_list_tr))
                else:
                    raise ValueError
        self.mode = mode
        if self.mode in ['train']:
            self.idx_list = self.ds_idx[:num_train]
        elif self.mode in ['eval']:
            # NOTE(review): eval takes the FIRST num_test permuted indices,
            # which overlap the training split — confirm this is intentional.
            self.idx_list = self.ds_idx[:num_test]
        elif self.mode in ['test']:
            self.idx_list = self.ds_idx[num_train:]
        self.x_list_ = [self.x_list[i] for i in self.idx_list]
        self.adj_list_ = [self.adj_list[i] for i in self.idx_list]
        self.arch_str_list_ = [self.arch_str_list[i] for i in self.idx_list]
        if self.label_list is not None:
            assert type(self.label_list) == list
            for label in self.label_list:
                # Restrict each label list to this split and standardize it
                # with the training-set statistics computed above.
                if label == 'test-acc':
                    self.test_acc_list_ = [self.test_acc_list[i] for i in self.idx_list]
                    self.test_acc_list_ = self.normalize(self.test_acc_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                elif label == 'flops':
                    self.flops_list_ = [self.flops_list[i] for i in self.idx_list]
                    self.flops_list_ = self.normalize(self.flops_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                elif label == 'params':
                    self.params_list_ = [self.params_list[i] for i in self.idx_list]
                    self.params_list_ = self.normalize(self.params_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                elif label == 'latency':
                    self.latency_list_ = [self.latency_list[i] for i in self.idx_list]
                    self.latency_list_ = self.normalize(self.latency_list_, LABEL_TO_MEAN_STD[label]['mean'], LABEL_TO_MEAN_STD[label]['std'])
                else:
                    raise ValueError

    def normalize(self, original, mean, std):
        """Standardize each value with the given mean/std."""
        return [(i-mean)/std for i in original]

    def get_not_connect_prev_adj(self):
        """Return the (n_adj x n_adj) chain adjacency: a single edge i -> i+1."""
        _adj = torch.zeros(self.n_adj, self.n_adj)
        for i in range(self.n_adj-1):
            _adj[i, i+1] = 1
        _adj = _adj.to(torch.float32).to('cpu') # torch.tensor(_adj, dtype=torch.float32, device=torch.device('cpu'))
        # if self.except_inout:
        #     _adj = _adj[1:-1, 1:-1]
        return _adj

    @property
    def adj(self):
        # All samples share one adjacency, so the first entry represents all.
        return self.adj_list_[0]

    # @property
    def mask(self, algo='floyd', data='ofa'):
        """Return the augmented mask derived from the shared adjacency."""
        from utils import aug_mask
        return aug_mask(self.adj, algo=algo, data=data)[0]

    def get_unnoramlized_entire_data(self, label, tg_dataset):
        """Return the raw (unnormalized) values of `label` for ALL samples."""
        entire_test_acc_list = self.data['test-acc'][tg_dataset]
        entire_flops_list = self.data['flops'][tg_dataset]
        entire_params_list = self.data['params'][tg_dataset]
        entire_latency_list = self.data['latency'][tg_dataset]
        if label == 'test-acc':
            return entire_test_acc_list
        elif label == 'flops':
            return entire_flops_list
        elif label == 'params':
            return entire_params_list
        elif label == 'latency':
            return entire_latency_list
        else:
            raise ValueError

    def get_unnoramlized_data(self, label, tg_dataset):
        """Return the raw values of `label` restricted to this split's indices."""
        entire_test_acc_list = self.data['test-acc'][tg_dataset]
        entire_flops_list = self.data['flops'][tg_dataset]
        entire_params_list = self.data['params'][tg_dataset]
        entire_latency_list = self.data['latency'][tg_dataset]
        if label == 'test-acc':
            return [entire_test_acc_list[i] for i in self.idx_list]
        elif label == 'flops':
            return [entire_flops_list[i] for i in self.idx_list]
        elif label == 'params':
            return [entire_params_list[i] for i in self.idx_list]
        elif label == 'latency':
            return [entire_latency_list[i] for i in self.idx_list]
        else:
            raise ValueError

    def __len__(self):
        return len(self.x_list_)

    def __getitem__(self, index):
        label_dict = {}
        # Dead in practice: __init__ raises when label_list is not None.
        if self.label_list is not None:
            assert type(self.label_list) == list
            for label in self.label_list:
                if label == 'test-acc':
                    label_dict[f"{label}"] = self.test_acc_list_[index]
                elif label == 'flops':
                    label_dict[f"{label}"] = self.flops_list_[index]
                elif label == 'params':
                    label_dict[f"{label}"] = self.params_list_[index]
                elif label == 'latency':
                    label_dict[f"{label}"] = self.latency_list_[index]
                else:
                    raise ValueError
        return self.x_list_[index], self.adj_list_[index], label_dict