2019-11-15 07:26:32 +01:00
##################################################
2020-01-14 14:52:06 +01:00
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
2019-11-15 07:26:32 +01:00
##################################################################
# Regularized Evolution for Image Classifier Architecture Search #
##################################################################
2019-11-14 03:55:42 +01:00
import os , sys , time , glob , random , argparse
import numpy as np , collections
from copy import deepcopy
import torch
import torch . nn as nn
from pathlib import Path
lib_dir = ( Path ( __file__ ) . parent / ' .. ' / ' .. ' / ' lib ' ) . resolve ( )
if str ( lib_dir ) not in sys . path : sys . path . insert ( 0 , str ( lib_dir ) )
from config_utils import load_config , dict2config , configure2str
from datasets import get_datasets , SearchDataset
from procedures import prepare_seed , prepare_logger , save_checkpoint , copy_checkpoint , get_optim_scheduler
from utils import get_model_infos , obtain_accuracy
from log_utils import AverageMeter , time_string , convert_secs2time
2020-01-14 14:52:06 +01:00
from nas_201_api import NASBench201API as API
2019-11-14 03:55:42 +01:00
from models import CellStructure , get_search_spaces
class Model(object):
    """Record pairing one candidate architecture with its accuracy.

    Both attributes start as ``None`` and are assigned by the caller after
    construction.
    """

    def __init__(self):
        self.arch = None      # candidate architecture object
        self.accuracy = None  # score used to rank candidates

    def __str__(self):
        """Return the default formatted text of the wrapped architecture."""
        return format(self.arch)
2020-01-14 14:52:06 +01:00
# This function mimics the training and evaluation procedure for a single architecture `arch`.
# The time_cost is calculated as the total training time for a few epochs (e.g., 12) plus the evaluation time for one epoch.
2019-11-14 03:55:42 +01:00
def train_and_eval ( arch , nas_bench , extra_info ) :
if nas_bench is not None :
arch_index = nas_bench . query_index_by_arch ( arch )
assert arch_index > = 0 , ' can not find this arch : {:} ' . format ( arch )
2019-12-28 05:42:36 +01:00
info = nas_bench . get_more_info ( arch_index , ' cifar10-valid ' , None , True )
2019-12-24 07:36:47 +01:00
valid_acc , time_cost = info [ ' valid-accuracy ' ] , info [ ' train-all-time ' ] + info [ ' valid-per-time ' ]
2019-12-23 01:19:09 +01:00
#_, valid_acc = info.get_metrics('cifar10-valid', 'x-valid' , 25, True) # use the validation accuracy after 25 training epochs
2019-11-14 03:55:42 +01:00
else :
# train a model from scratch.
raise ValueError ( ' NOT IMPLEMENT YET ' )
2019-12-24 07:36:47 +01:00
return valid_acc , time_cost
2019-11-14 03:55:42 +01:00
def random_architecture_func(max_nodes, op_names):
    """Build a sampler of uniformly random cell architectures.

    Args:
      max_nodes: number of nodes in the cell; node ``i`` (for ``1 <= i < max_nodes``)
        receives one edge from every earlier node ``j < i``.
      op_names: candidate operation names; each edge draws one with ``random.choice``.

    Returns:
      A zero-argument callable that returns a new ``CellStructure`` built from a
      list of per-node tuples of ``(op_name, input_node_index)`` pairs.
    """
    def random_architecture():
        # Draw order (node by node, input by input) is kept so seeded runs are
        # reproducible.
        genotypes = [
            tuple((random.choice(op_names), j) for j in range(i))
            for i in range(1, max_nodes)
        ]
        return CellStructure(genotypes)
    return random_architecture
def mutate_arch_func(op_names):
    """Build the mutation operator for regularized evolution.

    The returned callable deep-copies a parent architecture and switches the
    operation on exactly one randomly chosen edge to a *different* operation
    drawn uniformly from ``op_names``.

    Args:
      op_names: candidate operation names.

    Returns:
      A callable ``mutate_arch(parent_arch) -> child_arch``. ``parent_arch`` must
      expose a ``nodes`` list whose entries are tuples of
      ``(op_name, input_node_index)`` pairs; the parent itself is not modified.
      The callable raises ValueError when ``op_names`` contains no operation
      other than the one currently on the chosen edge.
    """
    def mutate_arch(parent_arch):
        child_arch = deepcopy(parent_arch)
        node_id = random.randint(0, len(child_arch.nodes) - 1)
        node_info = list(child_arch.nodes[node_id])
        snode_id = random.randint(0, len(node_info) - 1)
        edge = node_info[snode_id]
        current_op = edge[0]
        # Choose among the alternatives directly: the former rejection loop
        # never terminated when no alternative operation existed.
        alternatives = [op for op in op_names if op != current_op]
        if not alternatives:
            raise ValueError('cannot mutate edge with op {:}: no alternative in {:}'.format(current_op, list(op_names)))
        node_info[snode_id] = (random.choice(alternatives), edge[1])
        child_arch.nodes[node_id] = tuple(node_info)
        return child_arch
    return mutate_arch
2019-12-24 07:36:47 +01:00
def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, nas_bench, extra_info):
    """Algorithm for regularized evolution (i.e. aging evolution).

    Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
    Classifier Architecture Search".

    Args:
      cycles: the total number of models (initial population included) to
        evaluate. Only used as the stopping criterion when ``time_budget`` is None.
      population_size: the number of individuals to keep in the population.
      sample_size: the number of individuals that should participate in each tournament.
      time_budget: the upper bound of the searching cost (same units as the
        ``time_cost`` returned by ``train_and_eval``); None means "stop after
        ``cycles`` models" instead.
      random_arch: zero-argument callable returning a random architecture.
      mutate_arch: callable mapping a parent architecture to a child architecture.
      nas_bench: forwarded to ``train_and_eval``.
      extra_info: forwarded to ``train_and_eval``.

    Returns:
      history: a list of `Model` instances, representing all the models computed
        during the evolution experiment.
      total_time_cost: the sum of the ``time_cost`` values of all models in
        ``history`` plus the wall-clock seconds spent sampling/mutating. A child
        whose cost would push the total past ``time_budget`` ends the search and
        is neither counted nor added to ``history``.
    """
    population = collections.deque()
    history, total_time_cost = [], 0  # Not used by the algorithm, only used to report results.

    def _should_continue():
        # Budgeted search by default; fall back to the paper's cycle count
        # (len(history) < cycles) when no budget is given.
        if time_budget is None:
            return len(history) < cycles
        return total_time_cost < time_budget

    # Initialize the population with random models.
    while len(population) < population_size:
        model = Model()
        model.arch = random_arch()
        model.accuracy, time_cost = train_and_eval(model.arch, nas_bench, extra_info)
        population.append(model)
        history.append(model)
        total_time_cost += time_cost

    # Carry out evolution in cycles. Each cycle produces a model and removes
    # another.
    while _should_continue():
        start_time = time.time()
        # Tournament sampling with replacement. random.choice indexes the deque
        # directly, so no list copy is made per draw (same RNG consumption).
        sample = [random.choice(population) for _ in range(sample_size)]
        # The parent is the best model in the sample.
        parent = max(sample, key=lambda m: m.accuracy)
        # Create the child model and store it.
        child = Model()
        child.arch = mutate_arch(parent.arch)
        total_time_cost += time.time() - start_time
        child.accuracy, time_cost = train_and_eval(child.arch, nas_bench, extra_info)
        if time_budget is not None and total_time_cost + time_cost > time_budget:
            return history, total_time_cost
        total_time_cost += time_cost
        population.append(child)
        history.append(child)
        # Remove the oldest model (aging).
        population.popleft()
    return history, total_time_cost
2019-11-14 03:55:42 +01:00
2019-11-19 01:58:04 +01:00
def _build_extra_info(xargs, logger):
    """Load the R-EA config (and, with ``xargs.data_path``, the CIFAR loaders).

    Returns the ``extra_info`` dict passed to ``train_and_eval``, with keys
    'config', 'train_loader' and 'valid_loader' (loaders are None without data).
    """
    config_path = 'configs/nas-benchmark/algos/R-EA.config'
    if xargs.data_path is None:
        config = load_config(config_path, None, logger)
        logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
        return {'config': config, 'train_loader': None, 'valid_loader': None}

    train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
    split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
    cifar_split = load_config(split_Fpath, None, None)
    train_split, valid_split = cifar_split.train, cifar_split.valid
    logger.log('Load split file from {:}'.format(split_Fpath))
    config = load_config(config_path, {'class_num': class_num, 'xshape': xshape}, logger)
    # To split data: the validation indices come from the training set, but are
    # read through a copy that uses valid_data's transform.
    train_data_v2 = deepcopy(train_data)
    train_data_v2.transform = valid_data.transform
    valid_data = train_data_v2

    def _loader(dataset, indices):
        sampler = torch.utils.data.sampler.SubsetRandomSampler(indices)
        return torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, sampler=sampler, num_workers=xargs.workers, pin_memory=True)

    train_loader = _loader(train_data, train_split)
    valid_loader = _loader(valid_data, valid_split)
    logger.log('||||||| {:10s} ||||||| Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(train_loader), len(valid_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
    return {'config': config, 'train_loader': train_loader, 'valid_loader': valid_loader}


def main(xargs, nas_bench):
    """Run one regularized-evolution search and report the best architecture.

    Args:
      xargs: parsed command-line namespace (see the ``__main__`` parser).
      nas_bench: benchmark API object. It is used for the final lookup and is
        passed to the search only when ``xargs.ea_fast_by_api`` is true.

    Returns:
      ``(logger.log_dir, index)`` where ``index`` is
      ``nas_bench.query_index_by_arch(best_arch)``.
    """
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    # Use the parameter, not the module-global `args`, so main() also works
    # when called programmatically.
    logger = prepare_logger(xargs)
    assert xargs.dataset == 'cifar10', 'currently only support CIFAR-10'

    extra_info = _build_extra_info(xargs, logger)

    search_space = get_search_spaces('cell', xargs.search_space_name)
    random_arch = random_architecture_func(xargs.max_nodes, search_space)
    mutate_arch = mutate_arch_func(search_space)

    x_start_time = time.time()
    logger.log('{:} use nas_bench : {:}'.format(time_string(), nas_bench))
    logger.log('-' * 30 + ' start searching with the time budget of {:} s'.format(xargs.time_budget))
    search_bench = nas_bench if xargs.ea_fast_by_api else None
    history, total_cost = regularized_evolution(xargs.ea_cycles, xargs.ea_population, xargs.ea_sample_size, xargs.time_budget, random_arch, mutate_arch, search_bench, extra_info)
    logger.log('{:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost={:.2f} s).'.format(time_string(), len(history), total_cost, time.time() - x_start_time))

    best_arch = max(history, key=lambda m: m.accuracy).arch
    logger.log('{:} best arch is {:}'.format(time_string(), best_arch))

    info = nas_bench.query_by_arch(best_arch)
    if info is None:
        logger.log('Did not find this architecture : {:}.'.format(best_arch))
    else:
        logger.log('{:}'.format(info))
    logger.log('-' * 100)
    logger.close()
    return logger.log_dir, nas_bench.query_index_by_arch(best_arch)
2019-11-14 03:55:42 +01:00
if __name__ == '__main__':
    parser = argparse.ArgumentParser("Regularized Evolution Algorithm")
    parser.add_argument('--data_path', type=str, help='Path to dataset')
    parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
    # channels and number-of-cells
    parser.add_argument('--search_space_name', type=str, help='The search space name.')
    parser.add_argument('--max_nodes', type=int, help='The maximum number of nodes.')
    parser.add_argument('--channel', type=int, help='The number of channels.')
    parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
    parser.add_argument('--ea_cycles', type=int, help='The number of cycles in EA.')
    parser.add_argument('--ea_population', type=int, help='The population size in EA.')
    parser.add_argument('--ea_sample_size', type=int, help='The sample size in EA.')
    # default=0: an omitted int flag is None, and `None > 0` below raises TypeError.
    parser.add_argument('--ea_fast_by_api', type=int, default=0, help='Use our API to speed up the experiments or not.')
    parser.add_argument('--time_budget', type=int, help='The total time cost budget for searching (in seconds).')
    # log
    parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
    parser.add_argument('--save_dir', type=str, help='Folder to save checkpoints and log.')
    parser.add_argument('--arch_nas_dataset', type=str, help='The path to load the architecture dataset (tiny-nas-benchmark).')
    parser.add_argument('--print_freq', type=int, default=200, help='print frequency (default: 200)')
    parser.add_argument('--rand_seed', type=int, default=-1, help='manual seed')
    args = parser.parse_args()
    args.ea_fast_by_api = args.ea_fast_by_api > 0

    if args.arch_nas_dataset is None or not os.path.isfile(args.arch_nas_dataset):
        nas_bench = None
    else:
        print('{:} build NAS-Benchmark-API from {:}'.format(time_string(), args.arch_nas_dataset))
        nas_bench = API(args.arch_nas_dataset)

    if args.rand_seed < 0:
        # A negative seed requests a sweep: 500 runs, each with a fresh random seed.
        save_dir, all_indexes, num = None, [], 500
        for i in range(num):
            print('{:} : {:03d}/{:03d}'.format(time_string(), i, num))
            args.rand_seed = random.randint(1, 100000)
            save_dir, index = main(args, nas_bench)
            all_indexes.append(index)
        torch.save(all_indexes, save_dir / 'results.pth')
    else:
        main(args, nas_bench)