update codes
This commit is contained in:
		| @@ -22,8 +22,9 @@ bash ./scripts-cnn/search-acc-v2.sh 3 acc2 | |||||||
|  |  | ||||||
| Train the searched CNN on CIFAR | Train the searched CNN on CIFAR | ||||||
| ``` | ``` | ||||||
| bash ./scripts-cnn/train-cifar.sh 0 GDAS_F1 cifar10 | bash ./scripts-cnn/train-cifar.sh 0 GDAS_FG cifar10  cut | ||||||
| bash ./scripts-cnn/train-cifar.sh 0 GDAS_V1 cifar100 | bash ./scripts-cnn/train-cifar.sh 0 GDAS_F1 cifar10  cut | ||||||
|  | bash ./scripts-cnn/train-cifar.sh 0 GDAS_V1 cifar100 cut | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| Train the searched CNN on ImageNet | Train the searched CNN on ImageNet | ||||||
|   | |||||||
| @@ -236,7 +236,6 @@ def train(train_queue, valid_queue, model, criterion, base_optimizer, arch_optim | |||||||
|  |  | ||||||
|     #inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True) |     #inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True) | ||||||
|     targets = targets.cuda(non_blocking=True) |     targets = targets.cuda(non_blocking=True) | ||||||
|     data_time.update(time.time() - end) |  | ||||||
|  |  | ||||||
|     # get a random minibatch from the search queue with replacement |     # get a random minibatch from the search queue with replacement | ||||||
|     try: |     try: | ||||||
| @@ -246,6 +245,7 @@ def train(train_queue, valid_queue, model, criterion, base_optimizer, arch_optim | |||||||
|       input_search, target_search = next(valid_iter) |       input_search, target_search = next(valid_iter) | ||||||
|      |      | ||||||
|     target_search = target_search.cuda(non_blocking=True) |     target_search = target_search.cuda(non_blocking=True) | ||||||
|  |     data_time.update(time.time() - end) | ||||||
|  |  | ||||||
|     # update the architecture |     # update the architecture | ||||||
|     arch_optimizer.zero_grad() |     arch_optimizer.zero_grad() | ||||||
|   | |||||||
| @@ -195,12 +195,18 @@ GDAS_F1 = Genotype( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| # Combine DMS_V1 and DMS_F1 | # Combine DMS_V1 and DMS_F1 | ||||||
| GDAS_CC = Genotype( | GDAS_GF = Genotype( | ||||||
|   normal=[('skip_connect', 0, 0.13017432391643524), ('skip_connect', 1, 0.12947972118854523), ('skip_connect', 0, 0.13062666356563568), ('sep_conv_5x5', 2, 0.12980839610099792), ('sep_conv_3x3', 3, 0.12923765182495117), ('skip_connect', 0, 0.12901571393013), ('sep_conv_5x5', 4, 0.12938997149467468), ('sep_conv_3x3', 3, 0.1289220005273819)], |   normal=[('skip_connect', 0, 0.13017432391643524), ('skip_connect', 1, 0.12947972118854523), ('skip_connect', 0, 0.13062666356563568), ('sep_conv_5x5', 2, 0.12980839610099792), ('sep_conv_3x3', 3, 0.12923765182495117), ('skip_connect', 0, 0.12901571393013), ('sep_conv_5x5', 4, 0.12938997149467468), ('sep_conv_3x3', 3, 0.1289220005273819)], | ||||||
|   normal_concat=range(2, 6), |   normal_concat=range(2, 6), | ||||||
|   reduce=None, |   reduce=None, | ||||||
|   reduce_concat=range(2, 6) |   reduce_concat=range(2, 6) | ||||||
| ) | ) | ||||||
|  | GDAS_FG = Genotype( | ||||||
|  |   normal=[('skip_connect', 0, 0.16), ('skip_connect', 1, 0.13), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.17), ('sep_conv_3x3', 2, 0.15), ('skip_connect', 0, 0.16), ('sep_conv_3x3', 2, 0.15)], | ||||||
|  |   normal_concat=range(2, 6), | ||||||
|  |   reduce=[('sep_conv_5x5', 0, 0.12862831354141235), ('sep_conv_3x3', 1, 0.12783904373645782), ('sep_conv_5x5', 2, 0.12725995481014252), ('sep_conv_5x5', 1, 0.12705285847187042), ('dil_conv_5x5', 2, 0.12797553837299347), ('sep_conv_3x3', 1, 0.12737272679805756), ('sep_conv_5x5', 0, 0.12833961844444275), ('sep_conv_5x5', 1, 0.12758426368236542)], | ||||||
|  |   reduce_concat=range(2, 6) | ||||||
|  | ) | ||||||
|  |  | ||||||
| model_types = {'DARTS_V1': DARTS_V1, | model_types = {'DARTS_V1': DARTS_V1, | ||||||
|                'DARTS_V2': DARTS_V2, |                'DARTS_V2': DARTS_V2, | ||||||
| @@ -210,4 +216,5 @@ model_types = {'DARTS_V1': DARTS_V1, | |||||||
|                'ENASNet' : ENASNet, |                'ENASNet' : ENASNet, | ||||||
|                'GDAS_V1' : GDAS_V1, |                'GDAS_V1' : GDAS_V1, | ||||||
|                'GDAS_F1' : GDAS_F1, |                'GDAS_F1' : GDAS_F1, | ||||||
|                'GDAS_CC' : GDAS_CC} |                'GDAS_GF' : GDAS_GF, | ||||||
|  |                'GDAS_FG' : GDAS_FG} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user