##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##################################################
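# Usage sketch (illustrative: the script name, data path, and benchmark file
# below are assumptions, not fixed by this file; only the command-line flags
# are defined in the argparse section at the bottom):
#
#   python R.py --dataset cifar10 --data_path ${data_dir}/cifar.python \
#       --search_space_name nas-bench-201 --max_nodes 4 \
#       --time_budget 12000 --arch_nas_dataset ${benchmark_file} \
#       --save_dir ./output/RANDOM --rand_seed -1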
import os, sys, time, glob, random, argparse
import numpy as np, collections
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path

lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from config_utils import load_config, dict2config, configure2str
from datasets import get_datasets, SearchDataset
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint, get_optim_scheduler
from utils import get_model_infos, obtain_accuracy
from log_utils import AverageMeter, time_string, convert_secs2time
from models import get_search_spaces
from nas_201_api import NASBench201API as API
from R_EA import train_and_eval, random_architecture_func


def main(xargs, nas_bench):
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled = True
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads(xargs.workers)
  prepare_seed(xargs.rand_seed)
  logger = prepare_logger(xargs)
  # NAS-Bench-201 stores CIFAR-10 validation results under the name 'cifar10-valid'.
  if xargs.dataset == 'cifar10':
    dataname = 'cifar10-valid'
  else:
    dataname = xargs.dataset
  if xargs.data_path is not None:
    # Real-data mode: build the datasets, the train/valid split, and the loaders.
    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
    split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
    cifar_split = load_config(split_Fpath, None, None)
    train_split, valid_split = cifar_split.train, cifar_split.valid
    logger.log('Load split file from {:}'.format(split_Fpath))
    config_path = 'configs/nas-benchmark/algos/R-EA.config'
    config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
    # To split data: the validation set reuses the training images, but with the evaluation transform.
    train_data_v2 = deepcopy(train_data)
    train_data_v2.transform = valid_data.transform
    valid_data = train_data_v2
    search_data = SearchDataset(xargs.dataset, train_data, train_split, valid_split)
    # data loader
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=xargs.workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=xargs.workers, pin_memory=True)
    logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
    extra_info = {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}
  else:
    # Benchmark-only mode: accuracies come from the pre-computed NAS benchmark, so no loaders are needed.
    config_path = 'configs/nas-benchmark/algos/R-EA.config'
    config = load_config(config_path, None, logger)
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
    extra_info = {'config': config, 'train_loader': None, 'valid_loader': None}

  search_space = get_search_spaces('cell', xargs.search_space_name)
  random_arch = random_architecture_func(xargs.max_nodes, search_space)
  # x = random_arch() ; y = mutate_arch(x)
  x_start_time = time.time()
  logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
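  # For orientation: random_arch() samples a cell from the NAS-Bench-201-style
  # search space; its string form looks like the following illustrative (not
  # verbatim) example, where each |op~j| places operation `op` on the edge
  # coming from node j:
  #   |nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|none~1|skip_connect~2|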
  best_arch, best_acc, total_time_cost, history = None, -1, 0, []
  # for idx in range(xargs.random_num):
  while total_time_cost < xargs.time_budget:
    arch = random_arch()
    accuracy, cost_time = train_and_eval(arch, nas_bench, extra_info, dataname)
    # Only count an architecture if its (simulated) training cost still fits in the budget.
    if total_time_cost + cost_time > xargs.time_budget: break
    else: total_time_cost += cost_time
    history.append(arch)
    if best_arch is None or best_acc < accuracy:
      best_acc, best_arch = accuracy, arch
    logger.log('[{:03d}] : {:} : accuracy = {:.2f}%'.format(len(history), arch, accuracy))
  logger.log('{:} best arch is {:}, accuracy = {:.2f}%, visited {:} archs with {:.1f} s (real-cost = {:.3f} s).'.format(time_string(), best_arch, best_acc, len(history), total_time_cost, time.time() - x_start_time))

  # Query the benchmark for the full (200-epoch) statistics of the best architecture.
  info = nas_bench.query_by_arch(best_arch, '200')
  if info is None: logger.log('Did not find this architecture : {:}.'.format(best_arch))
  else: logger.log('{:}'.format(info))
  logger.log('-' * 100)
  logger.close()
  return logger.log_dir, nas_bench.query_index_by_arch(best_arch)
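
# Note: the index returned above can be fed back into the benchmark API for
# richer statistics. A minimal sketch, assuming the standard NASBench201API
# interface and an illustrative benchmark file name:
#
#   api = API('NAS-Bench-201-v1_1-096897.pth')
#   log_dir, index = main(args, api)
#   api.show(index)  # prints the stored training/evaluation results for that cell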


if __name__ == '__main__':
  parser = argparse.ArgumentParser("Random search baseline (shares train_and_eval with R_EA)")
  parser.add_argument('--data_path', type=str, help='Path to dataset')
  parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
  # channels and number-of-cells
  parser.add_argument('--search_space_name', type=str, help='The search space name.')
  parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
  parser.add_argument('--channel', type=int, help='The number of channels.')
  parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
  # parser.add_argument('--random_num', type=int, help='The number of randomly selected architectures.')
  parser.add_argument('--time_budget', type=int, help='The total time budget for searching (in seconds).')
  # log
  parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
  parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
  parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
  parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
  parser.add_argument('--rand_seed', type=int, help='manual seed')
  args = parser.parse_args()
  # if args.rand_seed is None or args.rand_seed < 0: args.rand_seed = random.randint(1, 100000)
  if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
    nas_bench = None
  else:
    print('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
    nas_bench = API(args.arch_nas_dataset)
  if args.rand_seed is None or args.rand_seed < 0:
    # A missing or negative seed triggers 500 independent runs, each with a fresh random seed.
    save_dir, all_indexes, num = None, [], 500
    for i in range(num):
      print('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
      args.rand_seed = random.randint(1, 100000)
      save_dir, index = main(args, nas_bench)
      all_indexes.append(index)
    torch.save(all_indexes, save_dir / 'results.pth')
  else:
    main(args, nas_bench)