################################################## # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 # ############################################################################## import os, sys, time, glob, random, argparse import numpy as np, collections from copy import deepcopy import torch import torch.nn as nn from pathlib import Path lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) from config_utils import load_config, dict2config, configure2str from datasets import get_datasets, SearchDataset from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler from utils import get_model_infos, obtain_accuracy from log_utils import AverageMeter, time_string, convert_secs2time from models import get_search_spaces from nas_201_api import NASBench201API as API from R_EA import train_and_eval, random_architecture_func def main(xargs, nas_bench): assert torch.cuda.is_available(), 'CUDA is not available.' torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True torch.set_num_threads( xargs.workers ) prepare_seed(xargs.rand_seed) logger = prepare_logger(args) if xargs.dataset == 'cifar10': dataname = 'cifar10-valid' else: dataname = xargs.dataset if xargs.data_path is not None: train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1) split_Fpath = 'configs/nas-benchmark/cifar-split.txt' cifar_split = load_config(split_Fpath, None, None) train_split, valid_split = cifar_split.train, cifar_split.valid logger.log('Load split file from {:}'.format(split_Fpath)) config_path = 'configs/nas-benchmark/algos/R-EA.config' config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger) # To split data train_data_v2 = deepcopy(train_data) train_data_v2.transform = valid_data.transform valid_data = train_data_v2 search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split) # data loader train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split) , num_workers=xargs.workers, pin_memory=True) valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True) logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size)) logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader} else: config_path = 'configs/nas-benchmark/algos/R-EA.config' config = load_config(config_path, None, logger) logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config)) extra_info = {'config': config, 'train_loader': None, 'valid_loader': None} search_space = get_search_spaces('cell', xargs.search_space_name) random_arch = random_architecture_func(xargs.max_nodes, search_space) #x =random_arch() ; y = mutate_arch(x) x_start_time = time.time() logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench)) best_arch, best_acc, total_time_cost, history = None, -1, 0, [] #for idx in range(xargs.random_num): while total_time_cost < xargs.time_budget: arch = random_arch() accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname) if total_time_cost + cost_time > xargs.time_budget: break else: total_time_cost += cost_time history.append(arch) if best_arch is None or best_acc < accuracy: best_acc, best_arch = accuracy, arch logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy)) logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visit {:} archs with {:.1f} s (real-cost = {:.3f} s).'.format(time_string(), best_arch, best_acc, len(history), total_time_cost, time.time()-x_start_time)) info = nas_bench.query_by_arch(best_arch, '200') if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch)) else : logger.log('{:}'.format(info)) logger.log('-'*100) logger.close() return logger.log_dir, nas_bench.query_index_by_arch( best_arch ) if __name__ == '__main__': parser = argparse.ArgumentParser("Random NAS") parser.add_argument('--data_path', type=str, help='Path to dataset') parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.') # channels and number-of-cells parser.add_argument('--search_space_name', type=str, help='The search space name.') parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.') parser.add_argument('--channel', type=int, help='The number of channels.') parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.') #parser.add_argument('--random_num', type=int, help='The number of random selected architectures.') parser.add_argument('--time_budget', type=int, help='The total time cost budge for searching (in seconds).') # log parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)') parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.') parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).') parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)') parser.add_argument('--rand_seed', type=int, help='manual seed') args = parser.parse_args() #if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000) if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset): nas_bench = None else: print ('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset)) nas_bench = API(args.arch_nas_dataset) if args.rand_seed < 0: save_dir, all_indexes, num = None, [], 500 for i in range(num): print ('{:} : {:03d}/{:03d}'.format(time_string(), i, num)) args.rand_seed = random.randint(1, 100000) save_dir, index = main(args, nas_bench) all_indexes.append( index ) torch.save(all_indexes, save_dir / 'results.pth') else: main(args, nas_bench)