"""Cell genotypes for hand-designed and searched CNN architectures.

A ``Genotype`` describes one normal cell and one reduction cell:

* ``normal`` / ``reduce`` -- list of ``(op_name, input_index, weight)``
  tuples. Consecutive pairs of entries appear to define the successive
  intermediate nodes 2, 3, ... of a cell (see the node annotations on
  ``ENASNet.reduce``). ``weight`` is ``1.0`` for the hand-designed cells and
  a fractional value for the GDAS searched cells (presumably the
  architecture probability of the chosen op -- confirm against the search
  code).
* ``normal_concat`` / ``reduce_concat`` -- indices of the nodes whose outputs
  are concatenated to form the cell output. Some genotypes store these as a
  ``list``, others as a ``range``; both are kept exactly as published.
"""
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

# Candidate operations of the search space.
# NOTE(review): the hand-designed genotypes below also use 'sep_conv_7x7'
# and 'conv_7x1_1x7', which are NOT in this list -- whatever builds ops for
# those cells must use a larger op table; confirm before relying on
# PRIMITIVES as the full op vocabulary.
PRIMITIVES = [
    'none',
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5',
]

NASNet = Genotype(
    normal=[
        ('sep_conv_5x5', 1, 1.0),
        ('sep_conv_3x3', 0, 1.0),
        ('sep_conv_5x5', 0, 1.0),
        ('sep_conv_3x3', 0, 1.0),
        ('avg_pool_3x3', 1, 1.0),
        ('skip_connect', 0, 1.0),
        ('avg_pool_3x3', 0, 1.0),
        ('avg_pool_3x3', 0, 1.0),
        ('sep_conv_3x3', 1, 1.0),
        ('skip_connect', 1, 1.0),
    ],
    normal_concat=[2, 3, 4, 5, 6],
    reduce=[
        ('sep_conv_5x5', 1, 1.0),
        ('sep_conv_7x7', 0, 1.0),
        ('max_pool_3x3', 1, 1.0),
        ('sep_conv_7x7', 0, 1.0),
        ('avg_pool_3x3', 1, 1.0),
        ('sep_conv_5x5', 0, 1.0),
        ('skip_connect', 3, 1.0),
        ('avg_pool_3x3', 2, 1.0),
        ('sep_conv_3x3', 2, 1.0),
        ('max_pool_3x3', 1, 1.0),
    ],
    reduce_concat=[4, 5, 6],
)

AmoebaNet = Genotype(
    normal=[
        ('avg_pool_3x3', 0, 1.0),
        ('max_pool_3x3', 1, 1.0),
        ('sep_conv_3x3', 0, 1.0),
        ('sep_conv_5x5', 2, 1.0),
        ('sep_conv_3x3', 0, 1.0),
        ('avg_pool_3x3', 3, 1.0),
        ('sep_conv_3x3', 1, 1.0),
        ('skip_connect', 1, 1.0),
        ('skip_connect', 0, 1.0),
        ('avg_pool_3x3', 1, 1.0),
    ],
    normal_concat=[4, 5, 6],
    reduce=[
        ('avg_pool_3x3', 0, 1.0),
        ('sep_conv_3x3', 1, 1.0),
        ('max_pool_3x3', 0, 1.0),
        ('sep_conv_7x7', 2, 1.0),
        ('sep_conv_7x7', 0, 1.0),
        ('avg_pool_3x3', 1, 1.0),
        ('max_pool_3x3', 0, 1.0),
        ('max_pool_3x3', 1, 1.0),
        ('conv_7x1_1x7', 0, 1.0),
        ('sep_conv_3x3', 5, 1.0),
    ],
    reduce_concat=[3, 4, 6],
)

DARTS_V1 = Genotype(
    normal=[
        ('sep_conv_3x3', 1, 1.0), ('sep_conv_3x3', 0, 1.0),
        ('skip_connect', 0, 1.0), ('sep_conv_3x3', 1, 1.0),
        ('skip_connect', 0, 1.0), ('sep_conv_3x3', 1, 1.0),
        ('sep_conv_3x3', 0, 1.0), ('skip_connect', 2, 1.0),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('max_pool_3x3', 0, 1.0), ('max_pool_3x3', 1, 1.0),
        ('skip_connect', 2, 1.0), ('max_pool_3x3', 0, 1.0),
        ('max_pool_3x3', 0, 1.0), ('skip_connect', 2, 1.0),
        ('skip_connect', 2, 1.0), ('avg_pool_3x3', 0, 1.0),
    ],
    reduce_concat=[2, 3, 4, 5],
)

DARTS_V2 = Genotype(
    normal=[
        ('sep_conv_3x3', 0, 1.0), ('sep_conv_3x3', 1, 1.0),
        ('sep_conv_3x3', 0, 1.0), ('sep_conv_3x3', 1, 1.0),
        ('sep_conv_3x3', 1, 1.0), ('skip_connect', 0, 1.0),
        ('skip_connect', 0, 1.0), ('dil_conv_3x3', 2, 1.0),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=[
        ('max_pool_3x3', 0, 1.0), ('max_pool_3x3', 1, 1.0),
        ('skip_connect', 2, 1.0), ('max_pool_3x3', 1, 1.0),
        ('max_pool_3x3', 0, 1.0), ('skip_connect', 2, 1.0),
        ('skip_connect', 2, 1.0), ('max_pool_3x3', 1, 1.0),
    ],
    reduce_concat=[2, 3, 4, 5],
)

# PNASNet uses the same cell for both the normal and the reduction position.
PNASNet = Genotype(
    normal=[
        ('sep_conv_5x5', 0, 1.0), ('max_pool_3x3', 0, 1.0),
        ('sep_conv_7x7', 1, 1.0), ('max_pool_3x3', 1, 1.0),
        ('sep_conv_5x5', 1, 1.0), ('sep_conv_3x3', 1, 1.0),
        ('sep_conv_3x3', 4, 1.0), ('max_pool_3x3', 1, 1.0),
        ('sep_conv_3x3', 0, 1.0), ('skip_connect', 1, 1.0),
    ],
    normal_concat=[2, 3, 4, 5, 6],
    reduce=[
        ('sep_conv_5x5', 0, 1.0), ('max_pool_3x3', 0, 1.0),
        ('sep_conv_7x7', 1, 1.0), ('max_pool_3x3', 1, 1.0),
        ('sep_conv_5x5', 1, 1.0), ('sep_conv_3x3', 1, 1.0),
        ('sep_conv_3x3', 4, 1.0), ('max_pool_3x3', 1, 1.0),
        ('sep_conv_3x3', 0, 1.0), ('skip_connect', 1, 1.0),
    ],
    reduce_concat=[2, 3, 4, 5, 6],
)

# https://arxiv.org/pdf/1802.03268.pdf
ENASNet = Genotype(
    normal=[
        ('sep_conv_3x3', 1, 1.0), ('skip_connect', 1, 1.0),
        ('sep_conv_5x5', 1, 1.0), ('skip_connect', 0, 1.0),
        ('avg_pool_3x3', 0, 1.0), ('sep_conv_3x3', 1, 1.0),
        ('sep_conv_3x3', 0, 1.0), ('avg_pool_3x3', 1, 1.0),
        ('sep_conv_5x5', 1, 1.0), ('avg_pool_3x3', 0, 1.0),
    ],
    normal_concat=[2, 3, 4, 5, 6],
    reduce=[
        ('sep_conv_5x5', 0, 1.0), ('sep_conv_3x3', 1, 1.0),  # 2
        ('sep_conv_3x3', 1, 1.0), ('avg_pool_3x3', 1, 1.0),  # 3
        ('sep_conv_3x3', 1, 1.0), ('avg_pool_3x3', 1, 1.0),  # 4
        ('avg_pool_3x3', 1, 1.0), ('sep_conv_5x5', 4, 1.0),  # 5
        ('sep_conv_3x3', 5, 1.0), ('sep_conv_5x5', 0, 1.0),  # 6
    ],
    reduce_concat=[2, 3, 4, 5, 6],
)

DARTS = DARTS_V2

# Search by normal and reduce
GDAS_V1 = Genotype(
    normal=[
        ('skip_connect', 0, 0.13017432391643524),
        ('skip_connect', 1, 0.12947972118854523),
        ('skip_connect', 0, 0.13062666356563568),
        ('sep_conv_5x5', 2, 0.12980839610099792),
        ('sep_conv_3x3', 3, 0.12923765182495117),
        ('skip_connect', 0, 0.12901571393013),
        ('sep_conv_5x5', 4, 0.12938997149467468),
        ('sep_conv_3x3', 3, 0.1289220005273819),
    ],
    normal_concat=range(2, 6),
    reduce=[
        ('sep_conv_5x5', 0, 0.12862831354141235),
        ('sep_conv_3x3', 1, 0.12783904373645782),
        ('sep_conv_5x5', 2, 0.12725995481014252),
        ('sep_conv_5x5', 1, 0.12705285847187042),
        ('dil_conv_5x5', 2, 0.12797553837299347),
        ('sep_conv_3x3', 1, 0.12737272679805756),
        ('sep_conv_5x5', 0, 0.12833961844444275),
        ('sep_conv_5x5', 1, 0.12758426368236542),
    ],
    reduce_concat=range(2, 6),
)

# Search by normal and fixing reduction.
# `reduce=None` marks that no searched reduction cell is provided
# (presumably a fixed, hand-designed reduction cell is used -- confirm in
# the model-building code).
GDAS_F1 = Genotype(
    normal=[
        ('skip_connect', 0, 0.16), ('skip_connect', 1, 0.13),
        ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15),
        ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15),
        ('skip_connect', 0, 0.16), ('sep_conv_3x3', 2, 0.15),
    ],
    normal_concat=[2, 3, 4, 5],
    reduce=None,
    reduce_concat=[2, 3, 4, 5],
)

# Combine GDAS_V1 and GDAS_F1.
# The cell tables are derived from their source genotypes instead of being
# re-typed, so they cannot drift apart. `list(...)` makes independent copies
# so mutating one genotype's cell never affects another.

# Normal cell searched jointly with reduction (GDAS_V1), fixed reduction.
GDAS_GF = Genotype(
    normal=list(GDAS_V1.normal),
    normal_concat=GDAS_V1.normal_concat,
    reduce=None,
    reduce_concat=GDAS_V1.reduce_concat,
)

# Normal cell searched with fixed reduction (GDAS_F1), searched reduction
# cell from GDAS_V1. normal_concat is a `range`, as originally published
# (GDAS_F1 itself stores a list with the same elements).
GDAS_FG = Genotype(
    normal=list(GDAS_F1.normal),
    normal_concat=range(2, 6),
    reduce=list(GDAS_V1.reduce),
    reduce_concat=GDAS_V1.reduce_concat,
)

# Registry mapping a genotype's name to its definition.
model_types = {'DARTS_V1': DARTS_V1,
               'DARTS_V2': DARTS_V2,
               'NASNet': NASNet,
               'PNASNet': PNASNet,
               'AmoebaNet': AmoebaNet,
               'ENASNet': ENASNet,
               'GDAS_V1': GDAS_V1,
               'GDAS_F1': GDAS_F1,
               'GDAS_GF': GDAS_GF,
               'GDAS_FG': GDAS_FG}