150 lines
4.6 KiB
Python
150 lines
4.6 KiB
Python
###########################################################################################
|
|
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
|
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
|
###########################################################################################
|
|
from __future__ import print_function
|
|
import os
|
|
import torch
|
|
from tqdm import tqdm
|
|
from torch.utils.data import Dataset
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
def get_meta_train_loader(batch_size, data_path, num_sample, is_pred=False):
    """Build the shuffled DataLoader over the meta-training task database.

    Args:
        batch_size: number of tasks per batch.
        data_path: root directory handed to ``MetaTrainDatabase``.
        num_sample: forwarded to ``MetaTrainDatabase`` (images drawn per class).
        is_pred: forwarded to ``MetaTrainDatabase``; selects the ``predictor``
            data directory instead of the ``generator`` one.

    Returns:
        A ``DataLoader`` with ``shuffle=True``, a single worker and this
        module's ``collate_fn`` as the collate function.
    """
    train_tasks = MetaTrainDatabase(data_path, num_sample, is_pred)
    print(f'==> The number of tasks for meta-training: {len(train_tasks)}')

    return DataLoader(
        dataset=train_tasks,
        batch_size=batch_size,
        shuffle=True,
        num_workers=1,
        collate_fn=collate_fn,
    )
|
|
|
|
|
|
def get_meta_test_loader(data_path, data_name, num_class=None, is_pred=False,
                         num_sample=None):
    """Build the (unshuffled) DataLoader over a meta-test dataset.

    The previous implementation passed ``num_class`` positionally into the
    ``num_sample`` slot of ``MetaTestDataset(data_path, data_name, num_sample,
    num_class)``, so the requested class count was never forwarded and was
    instead used as the per-class sample count. Both values are now passed to
    the parameter they belong to.

    Args:
        data_path: directory containing ``<data_name>bylabel.pt``.
        data_name: dataset identifier understood by ``MetaTestDataset``.
        num_class: number of classes to draw from; ``None`` lets
            ``MetaTestDataset`` pick its built-in default for ``data_name``.
        is_pred: accepted for signature compatibility; not used here.
        num_sample: images drawn per class. ``None`` (the default) means the
            per-class slice is unbounded, i.e. every image of each class is
            used -- the same result the old code produced for the default
            ``num_class=None`` call.

    Returns:
        A ``DataLoader`` with ``batch_size=100``, ``shuffle=False`` and one
        worker.
    """
    dataset = MetaTestDataset(data_path, data_name, num_sample, num_class)
    print(f'==> Meta-Test dataset {data_name}')

    return DataLoader(
        dataset=dataset,
        batch_size=100,
        shuffle=False,
        num_workers=1,
    )
|
|
|
|
|
|
class MetaTrainDatabase(Dataset):
    """Meta-training tasks backed by pre-split ``database_219152_14.0K`` files.

    Each item pairs a set of images (``num_sample`` per class of the task,
    drawn at random from ``imgnet32bylabel.pt``) with the task's stored
    network entry and its accuracy.

    Args:
        data_path: root directory holding ``imgnet32bylabel.pt`` and the
            ``generator/processed/`` or ``predictor/processed/`` folder.
        num_sample: number of images drawn per class in ``__getitem__``.
        is_pred: use the ``predictor`` folder instead of ``generator``.

    Raises:
        ValueError: if the ``*_train.pt`` split file does not exist.
    """

    def __init__(self, data_path, num_sample, is_pred=False):
        self.mode = 'train'
        self.acc_norm = True
        self.num_sample = num_sample
        self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))

        self.dpath = '{}/{}/processed/'.format(
            data_path, 'predictor' if is_pred else 'generator')
        self.dname = 'database_219152_14.0K'

        train_file = self.dpath + f'{self.dname}_train.pt'
        if not os.path.exists(train_file):
            # The split-building code used to sit (unreachably) after an
            # empty-message raise; it now lives in _build_split_files and the
            # error says which file is missing.
            raise ValueError(
                f'Pre-split meta-training file not found: {train_file}. '
                f'The {self.dname}_train/_valid/_test.pt files must be '
                f'generated before constructing MetaTrainDatabase.')

        self.set_mode(self.mode)

    def _build_split_files(self):
        """Split ``<dname>.pt`` at random into train/valid/test files.

        15% of the entries go to ``test``, 15% to ``valid`` and the rest to
        ``train``. The mean/std of the *training* accuracies are stored in
        all three files. Never called automatically: the constructor raises
        when the split files are missing, exactly as before.
        """
        database = torch.load(self.dpath + f'{self.dname}.pt')

        rand_idx = torch.randperm(len(database))
        test_len = int(len(database) * 0.15)
        idxlst = {'test': rand_idx[:test_len],
                  'valid': rand_idx[test_len:2 * test_len],
                  'train': rand_idx[2 * test_len:]}

        # 'train' must come first so that mean/std exist for the later splits.
        for m in ['train', 'valid', 'test']:
            acc, cls, net, flops = [], [], [], []
            for idx in tqdm(idxlst[m].tolist(), desc=f'data-{m}'):
                entry = database[idx]
                acc.append(entry['top1'])
                net.append(entry['net'])
                cls.append(entry['class'])
                flops.append(entry['flops'])
            if m == 'train':
                mean = torch.mean(torch.tensor(acc)).item()
                std = torch.std(torch.tensor(acc)).item()
            torch.save({'acc': acc,
                        'class': cls,
                        'net': net,
                        'flops': flops,
                        'mean': mean,
                        'std': std},
                       self.dpath + f'{self.dname}_{m}.pt')

    def set_mode(self, mode):
        """Switch to the ``mode`` split by loading ``<dname>_<mode>.pt``."""
        self.mode = mode
        data = torch.load(self.dpath + f'{self.dname}_{self.mode}.pt')
        self.acc = data['acc']
        self.cls = data['class']
        self.net = data['net']
        self.flops = data['flops']
        self.mean = data['mean']
        self.std = data['std']

    def __len__(self):
        """Number of tasks in the currently loaded split."""
        return len(self.acc)

    def __getitem__(self, index):
        """Return ``(images, net_entry, accuracy)`` for task ``index``.

        ``images`` concatenates ``num_sample`` randomly chosen images for each
        class of the task; ``accuracy`` is a ``(1, 1)`` tensor.
        """
        classes = self.cls[index]
        acc = self.acc[index]
        graph = self.net[index]

        data = []
        for cls in classes:
            # Each per-class entry keeps its images in element 0.
            cx = self.x[cls.item()][0]
            ridx = torch.randperm(len(cx))
            data.append(cx[ridx[:self.num_sample]])
        x = torch.cat(data)

        if self.acc_norm:
            # NOTE(review): the standardised accuracy is additionally divided
            # by 100; kept as-is, confirm this scaling is intended.
            acc = ((acc - self.mean) / self.std) / 100.0
        else:
            acc = acc / 100.0
        return x, graph, torch.tensor(acc).view(1, 1)
|
|
|
|
|
|
class MetaTestDataset(Dataset):
    """Endless sampler of random per-class image sets from one test dataset.

    Args:
        data_path: directory containing ``<name>bylabel.pt``.
        data_name: dataset identifier; ``'aircraft'`` is read from the
            ``aircraft100`` file.
        num_sample: number of images drawn per class for every item.
        num_class: number of classes to draw from; when ``None`` the built-in
            default for the dataset is used.
    """

    # Default class counts, keyed by the on-disk dataset name.
    # Note that 'aircraft100' is configured with 30 classes here.
    _DEFAULT_NUM_CLASS = {
        'cifar100': 100,
        'cifar10': 10,
        'mnist': 10,
        'aircraft100': 30,
        'svhn': 10,
        'pets': 37,
    }

    def __init__(self, data_path, data_name, num_sample, num_class=None):
        self.num_sample = num_sample
        # The caller-facing name is stored before any remapping.
        self.data_name = data_name
        file_key = 'aircraft100' if data_name == 'aircraft' else data_name

        if num_class is None:
            self.num_class = self._DEFAULT_NUM_CLASS[file_key]
        else:
            self.num_class = num_class

        self.x = torch.load(os.path.join(data_path, f'{file_key}bylabel.pt'))

    def __len__(self):
        """Fixed nominal length; every item is a fresh random draw."""
        return 1000000

    def _draw(self, label):
        """Pick ``num_sample`` random images of class ``label``."""
        images = self.x[label][0]
        order = torch.randperm(len(images))
        return images[order[:self.num_sample]]

    def __getitem__(self, index):
        """Return the concatenated random draws for classes ``0..num_class-1``.

        ``index`` is ignored.
        """
        return torch.cat([self._draw(label) for label in range(self.num_class)])
|
|
|
|
|
|
def collate_fn(batch):
    """Identity collate: hand the list of dataset items back unchanged.

    No stacking is performed, so consumers receive the raw per-item tuples.
    """
    return batch
|