diffusionNAG/NAS-Bench-201/main_exp/transfer_nag/loader.py
2024-03-15 14:38:51 +00:00

131 lines
4.1 KiB
Python

###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
def get_meta_train_loader(batch_size, data_path, num_sample, is_pred=True):
    """Create the shuffled ``DataLoader`` over the meta-training database.

    Args:
        batch_size: Number of dataset items grouped into one mini-batch.
        data_path: Directory passed on to ``MetaTrainDatabase``.
        num_sample: Per-class sample count passed on to ``MetaTrainDatabase``.
        is_pred: Passed on to ``MetaTrainDatabase`` unchanged.

    Returns:
        A ``DataLoader`` that reshuffles every epoch, loads in the calling
        process (``num_workers=0``) and assembles batches with ``collate_fn``.
    """
    meta_train_db = MetaTrainDatabase(data_path, num_sample, is_pred)
    print(f'==> The number of tasks for meta-training: {len(meta_train_db)}')

    loader_options = {
        'batch_size': batch_size,
        'shuffle': True,
        'num_workers': 0,
        # Graphs are kept as a plain Python list, so a custom collate is used.
        'collate_fn': collate_fn,
    }
    return DataLoader(dataset=meta_train_db, **loader_options)
def get_meta_test_loader(data_path, data_name, num_class=None, is_pred=False,
                         num_sample=None):
    """Create the (unshuffled) ``DataLoader`` over a meta-test dataset.

    ``MetaTestDataset`` takes ``(data_path, data_name, num_sample, num_class)``.
    Historically this function forwarded its ``num_class`` argument into the
    ``num_sample`` position, so the dataset's own ``num_class`` was always left
    at its default. That legacy forwarding is kept when ``num_sample`` is not
    supplied, so existing callers see no change; passing ``num_sample``
    explicitly selects the correct forwarding of both values.

    Args:
        data_path: Directory passed on to ``MetaTestDataset``.
        data_name: Dataset name passed on to ``MetaTestDataset``; also printed.
        num_class: With ``num_sample`` given, the number of classes to use.
            Without it (legacy path), this value lands in the dataset's
            ``num_sample`` slot.
        is_pred: Accepted for signature compatibility; not used here.
        num_sample: Per-class sample count. ``None`` (default) selects the
            legacy forwarding described above.

    Returns:
        A ``DataLoader`` with ``batch_size=100``, ``shuffle=False`` and
        ``num_workers=0``.
    """
    if num_sample is None:
        # Legacy path: third positional slot of MetaTestDataset is num_sample.
        # NOTE(review): confirm whether any caller relies on this behaviour.
        dataset = MetaTestDataset(data_path, data_name, num_class)
    else:
        dataset = MetaTestDataset(data_path, data_name, num_sample, num_class)
    print(f'==> Meta-Test dataset {data_name}')
    return DataLoader(dataset=dataset,
                      batch_size=100,
                      shuffle=False,
                      num_workers=0)
class MetaTrainDatabase(Dataset):
    """Meta-training records: a sampled task, a graph and an accuracy value.

    Each record references a task (a list of class labels), a graph object
    and an accuracy. Records are split into a ``'valid'`` part (the first 400
    entries of the stored index list) and a ``'train'`` part (the rest);
    ``set_mode`` selects which part indexing and ``len`` operate on.
    """

    def __init__(self, data_path, num_sample, is_pred=True):
        """Load all tensors/lists from ``data_path`` and prepare the splits.

        Args:
            data_path: Directory containing ``imgnet32bylabel.pt``,
                ``meta_train_tasks_predictor.pt``,
                ``meta_train_tasks_predictor_idx.pt`` and
                ``meta_train_task_lst.pt``.
            num_sample: Upper bound on samples drawn per class in
                ``__getitem__``.
            is_pred: Accepted but not used by this constructor.
        """
        self.mode = 'train'
        self.acc_norm = True
        self.num_sample = num_sample
        self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))

        records_file = os.path.join(
            data_path, 'meta_train_tasks_predictor.pt')
        index_file = os.path.join(
            data_path, 'meta_train_tasks_predictor_idx.pt')

        records = torch.load(records_file)
        self.acc = records['acc']
        self.task = records['task']
        self.graph = records['g']

        stored_order = torch.load(index_file)
        self.idx_lst = {
            'valid': stored_order[:400],
            'train': stored_order[400:],
        }

        # Normalisation statistics come from the training split only.
        self.acc = torch.tensor(self.acc)
        train_acc = self.acc[self.idx_lst['train']]
        self.mean = torch.mean(train_acc).item()
        self.std = torch.std(train_acc).item()

        self.task_lst = torch.load(os.path.join(
            data_path, 'meta_train_task_lst.pt'))

    def set_mode(self, mode):
        """Select the active split; must be a key of ``self.idx_lst``."""
        self.mode = mode

    def __len__(self):
        """Return the number of records in the active split."""
        return len(self.idx_lst[self.mode])

    def _draw(self, pool):
        """Return up to ``num_sample`` randomly chosen rows of ``pool``."""
        order = torch.randperm(len(pool))
        return pool[order[:self.num_sample]]

    def __getitem__(self, index):
        """Return ``(x, graph, acc)`` for position ``index`` of the split.

        ``x`` concatenates a fresh random draw from every class of the
        record's task. ``acc`` is divided by 100; when ``acc_norm`` is set it
        is first standardised with the training-split mean and std.
        """
        record = self.idx_lst[self.mode][index]
        labels = self.task_lst[self.task[record]]
        graph = self.graph[record]
        acc = self.acc[record]

        # Class labels are 1-based, hence the ``- 1`` when indexing self.x.
        x = torch.cat([self._draw(self.x[label - 1][0]) for label in labels])

        if not self.acc_norm:
            return x, graph, acc / 100.0
        return x, graph, ((acc - self.mean) / self.std) / 100.0
class MetaTestDataset(Dataset):
    """Endless-style sampler over one meta-test dataset stored by label.

    Every item is a freshly drawn set of samples covering classes
    ``0 .. num_class - 1`` of ``{data_name}bylabel.pt``; the requested index
    itself is ignored.
    """

    def __init__(self, data_path, data_name, num_sample, num_class=None):
        """Load the per-label data and settle the number of classes.

        Args:
            data_path: Directory containing ``{data_name}bylabel.pt``.
            data_name: Dataset name; also the key into the built-in class
                count table when ``num_class`` is not given.
            num_sample: Upper bound on samples drawn per class.
            num_class: Explicit class count; ``None`` uses the built-in table.

        Raises:
            KeyError: If ``num_class`` is ``None`` and ``data_name`` is not in
                the built-in table.
        """
        self.num_sample = num_sample
        self.data_name = data_name

        num_class_dict = {
            'cifar100': 100,
            'cifar10': 10,
            'mnist': 10,
            'svhn': 10,
            'aircraft': 30,
            'pets': 37
        }
        self.num_class = (
            num_class if num_class is not None else num_class_dict[data_name]
        )

        label_file = os.path.join(data_path, f'{data_name}bylabel.pt')
        self.x = torch.load(label_file)

    def __len__(self):
        """Return a fixed nominal length of one million items."""
        return 1000000

    def __getitem__(self, index):
        """Return a concatenated random draw from every class.

        ``index`` does not influence the result; each call re-samples.
        """
        drawn = []
        for label in range(self.num_class):
            pool = self.x[label][0]
            order = torch.randperm(len(pool))
            drawn.append(pool[order[:self.num_sample]])
        return torch.cat(drawn)
def collate_fn(batch):
    """Merge a list of ``(x, graph, acc)`` items into one batch.

    Args:
        batch: Sequence of items whose first three entries are the sample
            tensor, the graph object and the accuracy tensor.

    Returns:
        ``[x, graph, acc]`` where ``x`` and ``acc`` are stacked along a new
        leading dimension and ``graph`` is a plain list in batch order.
    """
    samples, graphs, accuracies = [], [], []
    for item in batch:
        samples.append(item[0])
        graphs.append(item[1])
        accuracies.append(item[2])
    # Graphs are left unstacked; only the tensors are combined.
    return [torch.stack(samples), graphs, torch.stack(accuracies)]