simplify baselines
This commit is contained in:
		| @@ -110,25 +110,30 @@ def main(xargs, nas_bench): | ||||
|   logger = prepare_logger(args) | ||||
|  | ||||
|   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' | ||||
|   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||
|   split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||
|   cifar_split = load_config(split_Fpath, None, None) | ||||
|   train_split, valid_split = cifar_split.train, cifar_split.valid | ||||
|   logger.log('Load split file from {:}'.format(split_Fpath)) | ||||
|   config_path = 'configs/nas-benchmark/algos/R-EA.config' | ||||
|   config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||
|   # To split data | ||||
|   train_data_v2 = deepcopy(train_data) | ||||
|   train_data_v2.transform = valid_data.transform | ||||
|   valid_data    = train_data_v2 | ||||
|   search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | ||||
|   # data loader | ||||
|   train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True) | ||||
|   valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) | ||||
|   logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) | ||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|   extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} | ||||
|    | ||||
|   if xargs.data_path is not None: | ||||
|     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||
|     cifar_split = load_config(split_Fpath, None, None) | ||||
|     train_split, valid_split = cifar_split.train, cifar_split.valid | ||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) | ||||
|     config_path = 'configs/nas-benchmark/algos/R-EA.config' | ||||
|     config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||
|     # To split data | ||||
|     train_data_v2 = deepcopy(train_data) | ||||
|     train_data_v2.transform = valid_data.transform | ||||
|     valid_data    = train_data_v2 | ||||
|     search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | ||||
|     # data loader | ||||
|     train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True) | ||||
|     valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) | ||||
|     logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) | ||||
|     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|     extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} | ||||
|   else: | ||||
|     config_path = 'configs/nas-benchmark/algos/R-EA.config' | ||||
|     config = load_config(config_path, None, logger) | ||||
|     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|     extra_info = {'config': config, 'train_loader': None, 'valid_loader': None} | ||||
|  | ||||
|   # nas dataset load | ||||
|   assert xargs.arch_nas_dataset is not None and os.path.isfile(xargs.arch_nas_dataset) | ||||
|   | ||||
| @@ -29,25 +29,30 @@ def main(xargs, nas_bench): | ||||
|   logger = prepare_logger(args) | ||||
|  | ||||
|   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' | ||||
|   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||
|   split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||
|   cifar_split = load_config(split_Fpath, None, None) | ||||
|   train_split, valid_split = cifar_split.train, cifar_split.valid | ||||
|   logger.log('Load split file from {:}'.format(split_Fpath)) | ||||
|   config_path = 'configs/nas-benchmark/algos/R-EA.config' | ||||
|   config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||
|   # To split data | ||||
|   train_data_v2 = deepcopy(train_data) | ||||
|   train_data_v2.transform = valid_data.transform | ||||
|   valid_data    = train_data_v2 | ||||
|   search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | ||||
|   # data loader | ||||
|   train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True) | ||||
|   valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) | ||||
|   logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) | ||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|   extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} | ||||
|  | ||||
|   if xargs.data_path is not None: | ||||
|     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||
|     cifar_split = load_config(split_Fpath, None, None) | ||||
|     train_split, valid_split = cifar_split.train, cifar_split.valid | ||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) | ||||
|     config_path = 'configs/nas-benchmark/algos/R-EA.config' | ||||
|     config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||
|     # To split data | ||||
|     train_data_v2 = deepcopy(train_data) | ||||
|     train_data_v2.transform = valid_data.transform | ||||
|     valid_data    = train_data_v2 | ||||
|     search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | ||||
|     # data loader | ||||
|     train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True) | ||||
|     valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) | ||||
|     logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) | ||||
|     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|     extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} | ||||
|   else: | ||||
|     config_path = 'configs/nas-benchmark/algos/R-EA.config' | ||||
|     config = load_config(config_path, None, logger) | ||||
|     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|     extra_info = {'config': config, 'train_loader': None, 'valid_loader': None} | ||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||
|   random_arch = random_architecture_func(xargs.max_nodes, search_space) | ||||
|   #x =random_arch() ; y = mutate_arch(x) | ||||
| @@ -71,7 +76,7 @@ def main(xargs, nas_bench): | ||||
|   logger.log('-'*100) | ||||
|   logger.close() | ||||
|   return logger.log_dir, nas_bench.query_index_by_arch( best_arch ) | ||||
|    | ||||
|  | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   | ||||
| @@ -172,24 +172,30 @@ def main(xargs, nas_bench): | ||||
|   logger = prepare_logger(args) | ||||
|  | ||||
|   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' | ||||
|   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||
|   split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||
|   cifar_split = load_config(split_Fpath, None, None) | ||||
|   train_split, valid_split = cifar_split.train, cifar_split.valid | ||||
|   logger.log('Load split file from {:}'.format(split_Fpath)) | ||||
|   config_path = 'configs/nas-benchmark/algos/R-EA.config' | ||||
|   config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||
|   # To split data | ||||
|   train_data_v2 = deepcopy(train_data) | ||||
|   train_data_v2.transform = valid_data.transform | ||||
|   valid_data    = train_data_v2 | ||||
|   search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | ||||
|   # data loader | ||||
|   train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True) | ||||
|   valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) | ||||
|   logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) | ||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|   extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} | ||||
|   if xargs.data_path is not None: | ||||
|     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||
|     cifar_split = load_config(split_Fpath, None, None) | ||||
|     train_split, valid_split = cifar_split.train, cifar_split.valid | ||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) | ||||
|     config_path = 'configs/nas-benchmark/algos/R-EA.config' | ||||
|     config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||
|     # To split data | ||||
|     train_data_v2 = deepcopy(train_data) | ||||
|     train_data_v2.transform = valid_data.transform | ||||
|     valid_data    = train_data_v2 | ||||
|     search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | ||||
|     # data loader | ||||
|     train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True) | ||||
|     valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) | ||||
|     logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) | ||||
|     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|     extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} | ||||
|   else: | ||||
|     config_path = 'configs/nas-benchmark/algos/R-EA.config' | ||||
|     config = load_config(config_path, None, logger) | ||||
|     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|     extra_info = {'config': config, 'train_loader': None, 'valid_loader': None} | ||||
|  | ||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||
|   random_arch = random_architecture_func(xargs.max_nodes, search_space) | ||||
|   | ||||
| @@ -99,24 +99,31 @@ def main(xargs, nas_bench): | ||||
|   logger = prepare_logger(args) | ||||
|  | ||||
|   assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10' | ||||
|   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||
|   split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||
|   cifar_split = load_config(split_Fpath, None, None) | ||||
|   train_split, valid_split = cifar_split.train, cifar_split.valid | ||||
|   logger.log('Load split file from {:}'.format(split_Fpath)) | ||||
|   config_path = 'configs/nas-benchmark/algos/R-EA.config' | ||||
|   config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||
|   # To split data | ||||
|   train_data_v2 = deepcopy(train_data) | ||||
|   train_data_v2.transform = valid_data.transform | ||||
|   valid_data    = train_data_v2 | ||||
|   search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | ||||
|   # data loader | ||||
|   train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True) | ||||
|   valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) | ||||
|   logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) | ||||
|   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|   extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} | ||||
|   if xargs.data_path is not None: | ||||
|     train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) | ||||
|     split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||
|     cifar_split = load_config(split_Fpath, None, None) | ||||
|     train_split, valid_split = cifar_split.train, cifar_split.valid | ||||
|     logger.log('Load split file from {:}'.format(split_Fpath)) | ||||
|     config_path = 'configs/nas-benchmark/algos/R-EA.config' | ||||
|     config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) | ||||
|     # To split data | ||||
|     train_data_v2 = deepcopy(train_data) | ||||
|     train_data_v2.transform = valid_data.transform | ||||
|     valid_data    = train_data_v2 | ||||
|     search_data   = SearchDataset(xargs.dataset, train_data, train_split, valid_split) | ||||
|     # data loader | ||||
|     train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True) | ||||
|     valid_loader  = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) | ||||
|     logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) | ||||
|     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|     extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} | ||||
|   else: | ||||
|     config_path = 'configs/nas-benchmark/algos/R-EA.config' | ||||
|     config = load_config(config_path, None, logger) | ||||
|     extra_info = {'config': config, 'train_loader': None, 'valid_loader': None} | ||||
|     logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) | ||||
|    | ||||
|    | ||||
|   search_space = get_search_spaces('cell', xargs.search_space_name) | ||||
|   policy    = Policy(xargs.max_nodes, search_space) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user