131 lines
4.1 KiB
Python
131 lines
4.1 KiB
Python
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
|
|
from __future__ import print_function
|
|
import os
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
def get_meta_train_loader(batch_size, data_path, num_sample, is_pred=True):
    """Create the shuffled DataLoader used for meta-training.

    Args:
        batch_size: Number of tasks per batch.
        data_path: Directory handed to ``MetaTrainDatabase``.
        num_sample: Forwarded to ``MetaTrainDatabase``.
        is_pred: Forwarded to ``MetaTrainDatabase``.

    Returns:
        A ``DataLoader`` over the meta-training tasks that shuffles every
        epoch, loads in the main process and batches with ``collate_fn``.
    """
    train_tasks = MetaTrainDatabase(data_path, num_sample, is_pred)
    print(f'==> The number of tasks for meta-training: {len(train_tasks)}')

    return DataLoader(
        dataset=train_tasks,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn,
    )
|
|
|
|
|
|
def get_meta_test_loader(data_path, data_name, num_class=None, is_pred=False):
    """Create the fixed-order DataLoader over a meta-test dataset.

    Args:
        data_path: Directory handed to ``MetaTestDataset``.
        data_name: Dataset name; also printed in the status line.
        num_class: Passed positionally as the third argument of
            ``MetaTestDataset``.
        is_pred: Accepted but not used by this function.

    Returns:
        A ``DataLoader`` with batch size 100, no shuffling and no worker
        processes.
    """
    # NOTE(review): the third positional parameter of MetaTestDataset is
    # named ``num_sample``, so the value received here as ``num_class``
    # lands in that slot -- confirm with the callers which meaning is
    # intended before renaming or re-routing it.
    test_set = MetaTestDataset(data_path, data_name, num_class)
    print(f'==> Meta-Test dataset {data_name}')

    return DataLoader(
        dataset=test_set,
        batch_size=100,
        shuffle=False,
        num_workers=0,
    )
|
|
|
|
|
|
class MetaTrainDatabase(Dataset):
    """Meta-training tasks: sampled images, a graph and an accuracy target.

    Each item is looked up through an index list that is split into a
    'valid' part (the first 400 entries) and a 'train' part (the rest);
    ``set_mode`` selects which part ``__len__`` and ``__getitem__`` use.
    """

    def __init__(self, data_path, num_sample, is_pred=True):
        """Load every tensor file the database needs from ``data_path``.

        Args:
            data_path: Directory containing 'imgnet32bylabel.pt',
                'meta_train_tasks_predictor.pt',
                'meta_train_tasks_predictor_idx.pt' and
                'meta_train_task_lst.pt'.
            num_sample: Number of rows drawn per class in ``__getitem__``.
            is_pred: Accepted but not used by this class.
        """
        self.mode = 'train'
        self.acc_norm = True
        self.num_sample = num_sample
        self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))

        tasks_file = os.path.join(data_path, 'meta_train_tasks_predictor.pt')
        index_file = os.path.join(
            data_path, 'meta_train_tasks_predictor_idx.pt')

        records = torch.load(tasks_file)
        raw_acc = records['acc']
        self.task = records['task']
        self.graph = records['g']

        # The first 400 shuffled indices are held out for validation.
        shuffled = torch.load(index_file)
        self.idx_lst = {
            'valid': shuffled[:400],
            'train': shuffled[400:],
        }

        # Normalisation statistics come from the training split only.
        self.acc = torch.tensor(raw_acc)
        train_acc = self.acc[self.idx_lst['train']]
        self.mean = torch.mean(train_acc).item()
        self.std = torch.std(train_acc).item()

        self.task_lst = torch.load(
            os.path.join(data_path, 'meta_train_task_lst.pt'))

    def set_mode(self, mode):
        """Select the index split ('train' or 'valid') served by the dataset."""
        self.mode = mode

    def __len__(self):
        """Return the number of entries in the currently selected split."""
        return len(self.idx_lst[self.mode])

    def __getitem__(self, index):
        """Return ``(x, graph, acc)`` for the ``index``-th entry of the split.

        ``x`` concatenates, for every class of the entry's task, up to
        ``num_sample`` randomly chosen rows of that class. ``acc`` is
        standardised with the training mean/std when ``acc_norm`` is set,
        and is divided by 100 in either case.
        """
        record = self.idx_lst[self.mode][index]
        class_ids = self.task_lst[self.task[record]]
        graph = self.graph[record]
        acc = self.acc[record]

        picked = []
        for cls in class_ids:
            # Class ids are shifted by one to index self.x -- presumably the
            # task list is 1-based; verify against the data files.
            pool = self.x[cls - 1][0]
            order = torch.randperm(len(pool))
            picked.append(pool[order[:self.num_sample]])
        x = torch.cat(picked)

        if not self.acc_norm:
            return x, graph, acc / 100.0
        return x, graph, ((acc - self.mean) / self.std) / 100.0
|
|
|
|
|
|
class MetaTestDataset(Dataset):
    """Endless random draws of per-class rows from one meta-test dataset."""

    def __init__(self, data_path, data_name, num_sample, num_class=None):
        """Load '<data_name>bylabel.pt' from ``data_path``.

        Args:
            data_path: Directory containing the per-label tensor file.
            data_name: Dataset name; selects both the file and, when
                ``num_class`` is None, the default class count.
            num_sample: Number of rows drawn per class in ``__getitem__``.
            num_class: Explicit class count overriding the built-in table.

        Raises:
            KeyError: If ``num_class`` is None and ``data_name`` is not one
                of the names in the built-in table.
        """
        self.num_sample = num_sample
        self.data_name = data_name

        default_classes = {
            'cifar100': 100,
            'cifar10': 10,
            'mnist': 10,
            'svhn': 10,
            'aircraft': 30,
            'pets': 37,
        }
        # The table is consulted only when the caller gave no class count.
        self.num_class = (
            default_classes[data_name] if num_class is None else num_class
        )

        self.x = torch.load(os.path.join(data_path, f'{data_name}bylabel.pt'))

    def __len__(self):
        """Return a fixed nominal length; items do not depend on the index."""
        return 1000000

    def __getitem__(self, index):
        """Return up to ``num_sample`` random rows of every class, concatenated.

        ``index`` is ignored: each call draws a fresh random permutation per
        class.
        """
        picked = []
        for cls in range(self.num_class):
            pool = self.x[cls][0]
            order = torch.randperm(len(pool))
            picked.append(pool[order[:self.num_sample]])
        return torch.cat(picked)
|
|
|
|
|
|
def collate_fn(batch):
    """Merge ``(x, graph, acc)`` items into one batch.

    The first and third elements of every item are stacked into tensors
    along a new leading dimension; the second elements are kept as a plain
    Python list.

    Returns:
        ``[x, graph, acc]`` as a list.
    """
    inputs, graphs, targets = [], [], []
    for item in batch:
        inputs.append(item[0])
        graphs.append(item[1])
        targets.append(item[2])
    return [torch.stack(inputs), graphs, torch.stack(targets)]
|