NAS-sharing-parameters: support 3 datasets / update ops / update PyPI
@@ -1,5 +1,5 @@
 ##################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 ##################################################
-from .get_dataset_with_transform import get_datasets
+from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
 from .SearchDatasetWrap import SearchDataset

@@ -6,8 +6,12 @@ import os.path as osp
 import numpy as np
 import torchvision.datasets as dset
 import torchvision.transforms as transforms
+from copy import deepcopy
 from PIL import Image
 
 from .DownsampledImageNet import ImageNet16
+from .SearchDatasetWrap import SearchDataset
+from config_utils import load_config
 
+
 Dataset2Class = {'cifar10' : 10,
@@ -177,6 +181,47 @@ def get_datasets(name, root, cutout):
   class_num = Dataset2Class[name]
   return train_data, test_data, xshape, class_num
 
+
+def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers):
+  # batch_size may be a single int or a (train, test) pair
+  if isinstance(batch_size, (list, tuple)):
+    batch, test_batch = batch_size
+  else:
+    batch, test_batch = batch_size, batch_size
+  if dataset == 'cifar10':
+    #split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
+    cifar_split = load_config('{:}/cifar-split.txt'.format(config_root), None, None)
+    train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation sets,
+    #logger.log('Load split file from {:}'.format(split_Fpath))     # which are two disjoint groups of the original CIFAR-10 training set
+    # duplicate the training set so the held-out split can use the evaluation transform
+    xvalid_data  = deepcopy(train_data)
+    if hasattr(xvalid_data, 'transforms'): # also sync `transforms` so the dataset repr prints the right transform
+      xvalid_data.transforms = valid_data.transform
+    xvalid_data.transform  = deepcopy( valid_data.transform )
+    search_data   = SearchDataset(dataset, train_data, train_split, valid_split)
+    # data loaders
+    search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers, pin_memory=True)
+    train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=workers, pin_memory=True)
+    valid_loader  = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=workers, pin_memory=True)
+  elif dataset == 'cifar100':
+    cifar100_test_split = load_config('{:}/cifar100-test-split.txt'.format(config_root), None, None)
+    search_train_data = train_data
+    search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
+    search_data   = SearchDataset(dataset, [search_train_data, search_valid_data], list(range(len(search_train_data))), cifar100_test_split.xvalid)
+    search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers, pin_memory=True)
+    train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True, num_workers=workers, pin_memory=True)
+    valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid), num_workers=workers, pin_memory=True)
+  elif dataset == 'ImageNet16-120':
+    imagenet_test_split = load_config('{:}/imagenet-16-120-test-split.txt'.format(config_root), None, None)
+    search_train_data = train_data
+    search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
+    search_data   = SearchDataset(dataset, [search_train_data, search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid)
+    search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True, num_workers=workers, pin_memory=True)
+    train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True, num_workers=workers, pin_memory=True)
+    valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=workers, pin_memory=True)
+  else:
+    raise ValueError('invalid dataset : {:}'.format(dataset))
+  return search_loader, train_loader, valid_loader
+
 #if __name__ == '__main__':
 #  train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1)
 #  import pdb; pdb.set_trace()
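For reference, below is a minimal usage sketch of the new get_nas_search_loaders. It is an illustration, not code from this commit: the data root, the config_root location, the batch sizes, and the assumption that SearchDataset yields paired (train, valid) batches are inferred from the diff above rather than taken verbatim from the repository.

# Hypothetical driver for the loaders added above; all paths and
# hyper-parameters here are illustrative assumptions.
import sys
sys.path.insert(0, 'lib')  # assumes the repo's lib/ package layout

from datasets import get_datasets, get_nas_search_loaders

# get_datasets returns (train_data, test_data, xshape, class_num);
# cutout=-1 disables cutout augmentation
train_data, valid_data, xshape, class_num = get_datasets(
    'cifar10', '/path/to/cifar.python', -1)

# config_root must contain cifar-split.txt (plus the cifar100 /
# ImageNet16-120 test-split files for the other two datasets)
search_loader, train_loader, valid_loader = get_nas_search_loaders(
    train_data, valid_data, 'cifar10',
    'configs/nas-benchmark',  # config_root: assumed location of the split files
    (64, 512),                # batch_size as a (search/train, eval) pair
    4)                        # workers

# each search-loader step is assumed to yield paired batches: one to update
# network weights, one to update architecture parameters
for base_inputs, base_targets, arch_inputs, arch_targets in search_loader:
  break

Passing a tuple for batch_size exercises the new branch at the top of get_nas_search_loaders; a plain int applies the same batch size to every loader.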