###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


def get_meta_train_loader(batch_size, data_path, num_sample, is_pred=True):
    """Build a shuffled DataLoader over the meta-training tasks.

    Args:
        batch_size: number of tasks per batch.
        data_path: directory holding the ``*.pt`` files read by
            ``MetaTrainDatabase``.
        num_sample: maximum number of rows drawn per class for each task.
        is_pred: forwarded to ``MetaTrainDatabase``, which does not use it.

    Returns:
        A ``DataLoader`` (``num_workers=0``) yielding ``[x, graph, acc]``
        lists assembled by ``collate_fn``.
    """
    dataset = MetaTrainDatabase(data_path, num_sample, is_pred)
    print(f'==> The number of tasks for meta-training: {len(dataset)}')
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=0,
                        collate_fn=collate_fn)
    return loader


def get_meta_test_loader(data_path, data_name, num_class=None, is_pred=False,
                         batch_size=100):
    """Build an unshuffled DataLoader over random draws from one test dataset.

    Args:
        data_path: directory holding ``{data_name}bylabel.pt``.
        data_name: dataset key, e.g. ``'cifar10'``.
        num_class: forwarded as ``MetaTestDataset``'s ``num_sample`` argument
            (see the note below); ``None`` keeps every row of each class.
        is_pred: accepted for API compatibility; unused.
        batch_size: items per batch (previously hard-coded to 100).

    Returns:
        A ``DataLoader`` (default collation, ``num_workers=0``).

    NOTE(review): ``num_class`` lands in ``MetaTestDataset``'s third parameter,
    which is ``num_sample``. The class count is therefore always taken from the
    dataset's built-in table. This mapping is preserved so that existing
    positional callers keep working. Confirm the call sites before renaming.
    """
    dataset = MetaTestDataset(data_path, data_name, num_sample=num_class)
    print(f'==> Meta-Test dataset {data_name}')
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=0)
    return loader


class MetaTrainDatabase(Dataset):
    """Meta-training entries: sampled task images, a graph and an accuracy.

    The first ``_NUM_VALID`` entries of the stored index list form the
    ``'valid'`` split and the rest form ``'train'``. ``set_mode`` selects which
    split ``__len__`` and ``__getitem__`` read from (default ``'train'``).
    """

    # Number of leading entries of the stored index list reserved for 'valid'.
    _NUM_VALID = 400

    def __init__(self, data_path, num_sample, is_pred=True):
        # `is_pred` is accepted for API compatibility but not used here.
        self.mode = 'train'
        self.acc_norm = True
        self.num_sample = num_sample
        # NOTE: torch.load unpickles its input -- only load trusted files.
        self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))

        mtr_data_path = os.path.join(data_path, 'meta_train_tasks_predictor.pt')
        idx_path = os.path.join(data_path, 'meta_train_tasks_predictor_idx.pt')
        data = torch.load(mtr_data_path)
        self.acc = data['acc']
        self.task = data['task']
        self.graph = data['g']

        random_idx_lst = torch.load(idx_path)
        self.idx_lst = {
            'valid': random_idx_lst[:self._NUM_VALID],
            'train': random_idx_lst[self._NUM_VALID:],
        }
        # as_tensor avoids the redundant copy (and PyTorch's copy-construct
        # warning) that torch.tensor incurs when 'acc' is already a tensor.
        self.acc = torch.as_tensor(self.acc)
        # Normalization statistics are computed on the training split only.
        train_acc = self.acc[self.idx_lst['train']]
        self.mean = torch.mean(train_acc).item()
        self.std = torch.std(train_acc).item()

        self.task_lst = torch.load(
            os.path.join(data_path, 'meta_train_task_lst.pt'))

    def set_mode(self, mode):
        """Select the split ('train' or 'valid') used for indexing."""
        self.mode = mode

    def __len__(self):
        return len(self.idx_lst[self.mode])

    def __getitem__(self, index):
        """Return ``(x, graph, acc)`` for the ``index``-th entry of the split.

        ``x`` concatenates, along dim 0, up to ``num_sample`` randomly chosen
        rows for every class of the entry's task. ``acc`` is divided by 100,
        after z-scoring with the train-split mean/std when ``acc_norm`` is set.
        """
        entry = self.idx_lst[self.mode][index]
        classes = self.task_lst[self.task[entry]]
        graph = self.graph[entry]
        acc = self.acc[entry]

        data = []
        for cls in classes:
            # Class ids are looked up as `cls - 1`, i.e. treated as 1-based here
            # (MetaTestDataset indexes with `cls` directly) -- TODO confirm.
            cx = self.x[cls - 1][0]
            perm = torch.randperm(len(cx))
            data.append(cx[perm[:self.num_sample]])
        x = torch.cat(data)

        if self.acc_norm:
            # NOTE(review): the z-score is further divided by 100 -- confirm
            # this extra scaling is intended.
            acc = ((acc - self.mean) / self.std) / 100.0
        else:
            acc = acc / 100.0
        return x, graph, acc


class MetaTestDataset(Dataset):
    """Random image sets drawn from a single meta-test dataset.

    Every item draws up to ``num_sample`` random rows from each of the first
    ``num_class`` classes and concatenates them along dim 0; ``index`` is
    ignored, so each access yields a fresh random draw.
    """

    def __init__(self, data_path, data_name, num_sample, num_class=None):
        """Load ``{data_name}bylabel.pt`` from ``data_path``.

        Raises:
            KeyError: if ``num_class`` is None and ``data_name`` is not in the
                built-in class-count table.
        """
        self.num_sample = num_sample  # None keeps every row of each class
        self.data_name = data_name

        num_class_dict = {
            'cifar100': 100,
            'cifar10': 10,
            'mnist': 10,
            'svhn': 10,
            'aircraft': 30,
            'pets': 37,
        }

        if num_class is not None:
            self.num_class = num_class
        else:
            self.num_class = num_class_dict[data_name]

        # NOTE: torch.load unpickles its input -- only load trusted files.
        self.x = torch.load(os.path.join(data_path, f'{data_name}bylabel.pt'))

    def __len__(self):
        # Nominal size: items are random draws, so this only bounds how many a
        # DataLoader will request.
        return 1000000

    def __getitem__(self, index):
        data = []
        for cls in range(self.num_class):
            cx = self.x[cls][0]
            perm = torch.randperm(len(cx))
            data.append(cx[perm[:self.num_sample]])
        return torch.cat(data)


def collate_fn(batch):
    """Collate ``(x, graph, acc)`` items into ``[x, graph, acc]``.

    ``x`` and ``acc`` are stacked with ``torch.stack`` (so all items' ``x``
    must share a shape); graphs are returned as a plain Python list.
    """
    x = torch.stack([item[0] for item in batch])
    graph = [item[1] for item in batch]
    acc = torch.stack([item[2] for item in batch])
    return [x, graph, acc]