2020-07-13 04:53:11 +02:00
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
##################################################################
# Regularized Evolution for Image Classifier Architecture Search #
##################################################################
2020-08-30 10:04:52 +02:00
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar10 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset ImageNet16-120 --search_space tss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar10 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset cifar100 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
# python ./exps/NATS-algos/regularized_ea.py --dataset ImageNet16-120 --search_space sss --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --rand_seed 1
2020-12-01 15:25:23 +01:00
# python ./exps/NATS-algos/regularized_ea.py --dataset ${dataset} --search_space ${search_space} --time_budget ${time_budget} --ea_cycles 200 --ea_population 10 --ea_sample_size 3 --use_proxy 0
2020-07-13 05:43:10 +02:00
##################################################################
2020-07-13 04:53:11 +02:00
import os , sys , time , glob , random , argparse
import numpy as np , collections
from copy import deepcopy
import torch
import torch . nn as nn
from pathlib import Path
lib_dir = ( Path ( __file__ ) . parent / ' .. ' / ' .. ' / ' lib ' ) . resolve ( )
if str ( lib_dir ) not in sys . path : sys . path . insert ( 0 , str ( lib_dir ) )
from config_utils import load_config , dict2config , configure2str
from datasets import get_datasets , SearchDataset
from procedures import prepare_seed , prepare_logger , save_checkpoint , copy_checkpoint , get_optim_scheduler
from utils import get_model_infos , obtain_accuracy
from log_utils import AverageMeter , time_string , convert_secs2time
from models import CellStructure , get_search_spaces
2020-07-30 15:07:11 +02:00
from nats_bench import create
2020-07-13 04:53:11 +02:00
class Model:
    """Record holding one candidate architecture and its measured accuracy."""

    def __init__(self):
        # Both fields start empty; the caller assigns them after construction.
        self.arch, self.accuracy = None, None

    def __str__(self):
        """Return the stored architecture rendered through the ' {:} ' template."""
        template = ' {:} '
        return template.format(self.arch)
def random_topology_func(op_names, max_nodes=4):
    """Build a sampler that draws random cell topologies.

    Args:
      op_names: sequence of candidate operation names; one is drawn with
        `random.choice` for every edge.
      max_nodes: number of nodes in the cell. Node `i` (for `i` from 1 to
        `max_nodes - 1`) gets one incoming edge from every earlier node `j < i`.

    Returns:
      A zero-argument function. Each call samples fresh operations and returns
      `CellStructure(genotypes)`, where `genotypes` is a list with one tuple of
      `(op_name, source_node)` pairs per node.
    """
    def random_architecture():
        genotypes = []
        for i in range(1, max_nodes):
            # One (operation, source-node) pair for each incoming edge of node i.
            genotypes.append(tuple((random.choice(op_names), j) for j in range(i)))
        return CellStructure(genotypes)
    return random_architecture
def random_size_func(info):
    """Build a sampler that draws random size architectures.

    Args:
      info: mapping read at sampling time; ' numbers ' gives how many entries to
        draw and ' candidates ' gives the sequence each entry is drawn from.

    Returns:
      A zero-argument function. Each call draws the entries independently with
      `random.choice` and returns them as strings joined by ' : '.
    """
    def random_architecture():
        picks = [str(random.choice(info[' candidates ']))
                 for _ in range(info[' numbers '])]
        return ' : '.join(picks)
    return random_architecture
def mutate_topology_func(op_names):
    """Build a mutation operator for cell topologies.

    The returned function clones the parent architecture, picks one node and one
    of that node's incoming edges uniformly at random, and replaces the edge's
    operation with a different one drawn from `op_names`. The parent is never
    modified.

    Args:
      op_names: sequence of candidate operation names.

    Returns:
      A function `mutate(parent_arch) -> child_arch`. `parent_arch` must expose a
      mutable `nodes` sequence whose items are sequences of `(op_name, source)`
      pairs. If `op_names` offers no operation different from the selected
      edge's current one, the unmodified clone is returned.
    """
    def mutate_architecture(parent_arch):
        child_arch = deepcopy(parent_arch)
        node_id = random.randint(0, len(child_arch.nodes) - 1)
        node_info = list(child_arch.nodes[node_id])
        snode_id = random.randint(0, len(node_info) - 1)
        current_op = node_info[snode_id][0]
        # Without an alternative operation the rejection loop below would never
        # terminate, so hand back the plain clone instead.
        if op_names and all(op == current_op for op in op_names):
            return child_arch
        # Rejection sampling: redraw until the operation actually changes.
        xop = random.choice(op_names)
        while xop == current_op:
            xop = random.choice(op_names)
        node_info[snode_id] = (xop, node_info[snode_id][1])
        child_arch.nodes[node_id] = tuple(node_info)
        return child_arch
    return mutate_architecture
def mutate_size_func(info):
    """Build a mutation operator for size architectures.

    The returned function splits the parent string on ' : ', picks one position
    uniformly at random and overwrites it with a value drawn from
    `info[' candidates ']`. The draw does not exclude the current value, so the
    child can be identical to the parent.

    Args:
      info: mapping whose ' candidates ' entry is the sequence of values a
        position may take.

    Returns:
      A function `mutate(parent_arch) -> child_arch` operating on ' : '-joined
      strings. The parent string is not modified.
    """
    def mutate_architecture(parent_arch):
        # `split` already returns a new list, so no defensive copy is needed.
        channels = parent_arch.split(' : ')
        index = random.randint(0, len(channels) - 1)
        channels[index] = str(random.choice(info[' candidates ']))
        return ' : '.join(channels)
    return mutate_architecture
2020-12-01 05:34:00 +01:00
def regularized_evolution(cycles, population_size, sample_size, time_budget, random_arch, mutate_arch, api, use_proxy, dataset):
    """Algorithm for regularized evolution (i.e. aging evolution).

    Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image
    Classifier Architecture Search".

    Args:
      cycles: accepted for interface compatibility but not read; the search
        stops on `time_budget`, not on a cycle count.
      population_size: the number of individuals to keep in the population.
      sample_size: the number of individuals drawn (with replacement) for each
        tournament.
      time_budget: the search continues while the latest cost reported by
        `api.simulate_train_eval` is below this value.
      random_arch: zero-argument callable returning a random architecture.
      mutate_arch: callable mapping a parent architecture to a child.
      api: benchmark object providing `reset_time()`,
        `simulate_train_eval(arch, dataset, hp=...)`, `query_index_by_arch(arch)`
        and, when `use_proxy` is false, `full_train_epochs`.
      use_proxy: if true, architectures are evaluated with hp=' 12 ';
        otherwise with hp=api.full_train_epochs.
      dataset: dataset identifier forwarded to `api.simulate_train_eval`.

    Returns:
      A tuple `(history, current_best_index, total_time_cost)` of three lists
      with one entry per evaluated architecture:
        history: `(accuracy, arch)` tuples in evaluation order.
        current_best_index: `api.query_index_by_arch` of the most accurate
          architecture seen so far (the earliest one wins ties).
        total_time_cost: the cost value returned by the api for that evaluation.

    Raises:
      ValueError: if `population_size` or `sample_size` is smaller than 1.
    """
    if population_size < 1:
        raise ValueError('population_size must be at least 1, got {:}'.format(population_size))
    if sample_size < 1:
        raise ValueError('sample_size must be at least 1, got {:}'.format(sample_size))

    population = collections.deque()
    api.reset_time()
    history, total_time_cost = [], []  # Not used by the algorithm, only used to report results.
    current_best_index = []
    best_entry = None  # (accuracy, arch) of the best model evaluated so far.

    def evaluate_and_record(arch):
        """Evaluate `arch`, add it to the population and update the reports."""
        nonlocal best_entry
        model = Model()
        model.arch = arch
        model.accuracy, _, _, total_cost = api.simulate_train_eval(
            model.arch, dataset, hp=' 12 ' if use_proxy else api.full_train_epochs)
        population.append(model)
        entry = (model.accuracy, model.arch)
        history.append(entry)
        total_time_cost.append(total_cost)
        # Running maximum; the strict comparison keeps the earliest model on
        # ties, matching max(history, key=accuracy).
        if best_entry is None or entry[0] > best_entry[0]:
            best_entry = entry
        current_best_index.append(api.query_index_by_arch(best_entry[1]))

    # Initialize the population with random models.
    while len(population) < population_size:
        evaluate_and_record(random_arch())

    # Carry out evolution in cycles. Each cycle produces a model and removes another.
    while total_time_cost[-1] < time_budget:
        # Sample randomly chosen models (with replacement) from the current
        # population; the population does not change while sampling.
        pool = list(population)
        sample = [random.choice(pool) for _ in range(sample_size)]
        # The parent is the best model in the sample.
        parent = max(sample, key=lambda candidate: candidate.accuracy)
        # Create the child model and store it.
        evaluate_and_record(mutate_arch(parent.arch))
        # Remove the oldest model.
        population.popleft()
    return history, current_best_index, total_time_cost
2020-07-13 04:53:11 +02:00
def main(xargs, api):
    """Run one regularized-evolution search and log its outcome.

    Args:
      xargs: parsed options. Reads `rand_seed`, `search_space`, `time_budget`,
        `ea_cycles`, `ea_population`, `ea_sample_size`, `use_proxy` and `dataset`;
        the whole object is also handed to `prepare_logger`.
      api: benchmark API object forwarded to `regularized_evolution` and queried
        for the final architecture's info string.

    Returns:
      A tuple `(log_dir, current_best_index, total_times)`: the logger's
      directory plus the per-evaluation best-architecture indices and costs
      returned by `regularized_evolution`.
    """
    torch.set_num_threads(4)
    prepare_seed(xargs.rand_seed)
    # Use the argument that was passed in, not a module-level global, so the
    # function also works when called from another module.
    logger = prepare_logger(xargs)

    search_space = get_search_spaces(xargs.search_space, ' nats-bench ')
    # Pick the sampler/mutator pair matching the kind of search space.
    if xargs.search_space == ' tss ':
        random_arch = random_topology_func(search_space)
        mutate_arch = mutate_topology_func(search_space)
    else:
        random_arch = random_size_func(search_space)
        mutate_arch = mutate_size_func(search_space)

    x_start_time = time.time()
    logger.log(' {:} use api : {:} '.format(time_string(), api))
    logger.log(' - ' * 30 + ' start searching with the time budget of {:} s '.format(xargs.time_budget))
    history, current_best_index, total_times = regularized_evolution(xargs.ea_cycles,
                                                                     xargs.ea_population,
                                                                     xargs.ea_sample_size,
                                                                     xargs.time_budget,
                                                                     random_arch, mutate_arch, api, xargs.use_proxy > 0, xargs.dataset)
    # total_times[-1] is the cost reported by the api; the second figure is
    # the wall-clock time this process actually spent.
    logger.log(' {:} regularized_evolution finish with history of {:} arch with {:.1f} s (real-cost= {:.2f} s). '.format(time_string(), len(history), total_times[-1], time.time() - x_start_time))
    # history holds (accuracy, arch) tuples; take the arch with the top accuracy.
    best_arch = max(history, key=lambda x: x[0])[1]
    logger.log(' {:} best arch is {:} '.format(time_string(), best_arch))

    info = api.query_info_str_by_arch(best_arch, ' 200 ' if xargs.search_space == ' tss ' else ' 90 ')
    logger.log(' {:} '.format(info))
    logger.log(' - ' * 100)
    logger.close()
    return logger.log_dir, current_best_index, total_times
2020-07-13 04:53:11 +02:00
if __name__ == ' __main__ ':
    # Script entry: parse the options, load the benchmark API, then run either
    # one seeded search or a series of randomly seeded searches.
    parser = argparse.ArgumentParser(" Regularized Evolution Algorithm ")
    parser.add_argument(' --dataset ', type=str, choices=[' cifar10 ', ' cifar100 ', ' ImageNet16-120 '], help=' Choose between Cifar10/100 and ImageNet-16. ')
    parser.add_argument(' --search_space ', type=str, choices=[' tss ', ' sss '], help=' Choose the search space. ')
    # hyperparameters for REA
    parser.add_argument(' --ea_cycles ', type=int, help=' The number of cycles in EA. ')
    parser.add_argument(' --ea_population ', type=int, help=' The population size in EA. ')
    parser.add_argument(' --ea_sample_size ', type=int, help=' The sample size in EA. ')
    parser.add_argument(' --time_budget ', type=int, default=20000, help=' The total time cost budge for searching (in seconds). ')
    parser.add_argument(' --use_proxy ', type=int, default=1, help=' Whether to use the proxy (H0) task or not. ')
    #
    parser.add_argument(' --loops_if_rand ', type=int, default=500, help=' The total runs for evaluation. ')
    # log
    parser.add_argument(' --save_dir ', type=str, default=' ./output/search ', help=' Folder to save checkpoints and log. ')
    parser.add_argument(' --rand_seed ', type=int, default=-1, help=' manual seed ')
    args = parser.parse_args()

    # With a negative seed the results file is written into the directory of
    # the last run, so at least one run is needed. Reject the combination
    # before the benchmark is loaded.
    if args.rand_seed < 0 and args.loops_if_rand < 1:
        parser.error('--loops_if_rand must be at least 1 when --rand_seed is negative')

    api = create(None, args.search_space, fast_mode=True, verbose=False)

    # Output layout: <save_dir>-<space>/<dataset>-T<budget>[-FULL]/R-EA-SS<sample size>
    args.save_dir = os.path.join(' {:} - {:} '.format(args.save_dir, args.search_space),
                                 ' {:} -T {:} {:} '.format(args.dataset, args.time_budget, ' ' if args.use_proxy > 0 else ' -FULL '),
                                 ' R-EA-SS {:} '.format(args.ea_sample_size))
    print(' save-dir : {:} '.format(args.save_dir))
    print(' xargs : {:} '.format(args))

    if args.rand_seed < 0:
        # No seed given: repeat the search with fresh random seeds and store
        # every run's traces in a single file.
        save_dir, all_info = None, collections.OrderedDict()
        for i in range(args.loops_if_rand):
            print(' {:} : {:03d} / {:03d} '.format(time_string(), i, args.loops_if_rand))
            args.rand_seed = random.randint(1, 100000)
            save_dir, all_archs, all_total_times = main(args, api)
            all_info[i] = {' all_archs ': all_archs,
                           ' all_total_times ': all_total_times}
        save_path = save_dir / ' results.pth '
        print(' save into {:} '.format(save_path))
        torch.save(all_info, save_path)
    else:
        main(args, api)