# diffusionNAG/MobileNetV3/main_exp/transfer_nag_lib/MetaD2A_mobilenetV3/loader.py
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import torch
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


def get_meta_train_loader(batch_size, data_path, num_sample, is_pred=False):
    dataset = MetaTrainDatabase(data_path, num_sample, is_pred)
    print(f'==> The number of tasks for meta-training: {len(dataset)}')
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=1,
                        collate_fn=collate_fn)
    return loader


def get_meta_test_loader(data_path, data_name, num_sample, num_class=None, is_pred=False):
    # MetaTestDataset requires the per-class sample count, so num_sample is
    # accepted here and forwarded together with the optional class count.
    dataset = MetaTestDataset(data_path, data_name, num_sample, num_class)
    print(f'==> Meta-Test dataset {data_name}')
    loader = DataLoader(dataset=dataset,
                        batch_size=100,
                        shuffle=False,
                        num_workers=1)
    return loader


class MetaTrainDatabase(Dataset):
    def __init__(self, data_path, num_sample, is_pred=False):
        self.mode = 'train'
        self.acc_norm = True
        self.num_sample = num_sample
        self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))

        self.dpath = '{}/{}/processed/'.format(data_path, 'predictor' if is_pred else 'generator')
        self.dname = 'database_219152_14.0K'

        # Build the train/valid/test splits once if the processed files are missing.
        if not os.path.exists(self.dpath + f'{self.dname}_train.pt'):
            database = torch.load(self.dpath + f'{self.dname}.pt')

            # Random 70/15/15 split of the raw database.
            rand_idx = torch.randperm(len(database))
            test_len = int(len(database) * 0.15)
            idxlst = {'test': rand_idx[:test_len],
                      'valid': rand_idx[test_len:2 * test_len],
                      'train': rand_idx[2 * test_len:]}

            for m in ['train', 'valid', 'test']:
                acc, graph, cls, net, flops = [], [], [], [], []
                for idx in tqdm(idxlst[m].tolist(), desc=f'data-{m}'):
                    acc.append(database[idx]['top1'])
                    net.append(database[idx]['net'])
                    cls.append(database[idx]['class'])
                    flops.append(database[idx]['flops'])
                if m == 'train':
                    # Accuracy statistics of the training split; reused to
                    # normalize the validation and test splits as well.
                    mean = torch.mean(torch.tensor(acc)).item()
                    std = torch.std(torch.tensor(acc)).item()
                torch.save({'acc': acc,
                            'class': cls,
                            'net': net,
                            'flops': flops,
                            'mean': mean,
                            'std': std},
                           self.dpath + f'{self.dname}_{m}.pt')

        self.set_mode(self.mode)

    def set_mode(self, mode):
        # Switch between the cached 'train' / 'valid' / 'test' splits.
        self.mode = mode
        data = torch.load(self.dpath + f'{self.dname}_{self.mode}.pt')
        self.acc = data['acc']
        self.cls = data['class']
        self.net = data['net']
        self.flops = data['flops']
        self.mean = data['mean']
        self.std = data['std']

    def __len__(self):
        return len(self.acc)

    def __getitem__(self, index):
        data = []
        classes = self.cls[index]
        acc = self.acc[index]
        graph = self.net[index]
        # Sample num_sample images (from the by-label ImageNet-32 dump) for
        # each class of this task.
        for i, cls in enumerate(classes):
            cx = self.x[cls.item()][0]
            ridx = torch.randperm(len(cx))
            data.append(cx[ridx[:self.num_sample]])
        x = torch.cat(data)
        # Normalize the top-1 accuracy: either standardize it with the
        # training-split statistics or just rescale it from percent.
        if self.acc_norm:
            acc = ((acc - self.mean) / self.std) / 100.0
        else:
            acc = acc / 100.0
        return x, graph, torch.tensor(acc).view(1, 1)


class MetaTestDataset(Dataset):
    def __init__(self, data_path, data_name, num_sample, num_class=None):
        self.num_sample = num_sample
        self.data_name = data_name
        if data_name == 'aircraft':
            data_name = 'aircraft100'

        num_class_dict = {
            'cifar100': 100,
            'cifar10': 10,
            'mnist': 10,
            'aircraft100': 30,
            'svhn': 10,
            'pets': 37
        }
        # 'aircraft30': 30,
        # 'aircraft100': 100,
        if num_class is not None:
            self.num_class = num_class
        else:
            self.num_class = num_class_dict[data_name]

        self.x = torch.load(os.path.join(data_path, f'{data_name}bylabel.pt'))

    def __len__(self):
        # Samples are drawn on the fly, so the length is a large nominal value
        # rather than the true number of distinct tasks.
        return 1000000

    def __getitem__(self, index):
        # Sample num_sample images from every class of the target dataset.
        data = []
        classes = list(range(self.num_class))
        for cls in classes:
            cx = self.x[cls][0]
            ridx = torch.randperm(len(cx))
            data.append(cx[ridx[:self.num_sample]])
        x = torch.cat(data)
        return x


def collate_fn(batch):
    # Return the batch as a raw list of (x, graph, acc) samples; any stacking
    # of the image features or accuracies is left to the caller.
    # x = torch.stack([item[0] for item in batch])
    # graph = [item[1] for item in batch]
    # acc = torch.stack([item[2] for item in batch])
    return batch
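

# --- Usage sketch (not part of the original file) -------------------------------------
# A minimal example of how these loaders are typically driven, assuming the
# preprocessed files referenced above (imgnet32bylabel.pt and the
# database_219152_14.0K*.pt splits) already exist under DATA_PATH.
# DATA_PATH, the batch size, and num_sample below are illustrative placeholders.
if __name__ == '__main__':
    DATA_PATH = './data'  # assumption: adjust to the actual data directory

    train_loader = get_meta_train_loader(batch_size=32, data_path=DATA_PATH,
                                         num_sample=20, is_pred=True)
    for batch in train_loader:
        # collate_fn returns the batch as a list of (x, graph, acc) tuples.
        x, graph, acc = batch[0]
        print('meta-train sample:', x.shape, acc.shape)
        break

    test_loader = get_meta_test_loader(DATA_PATH, 'cifar10', num_sample=20)
    for x in test_loader:
        print('meta-test batch:', x.shape)
        break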