update tf-GDAS
README.md
@@ -1,4 +1,7 @@
-# Automated Deep Learning (AutoDL)
+<p align="center">
+<img src="https://xuanyidong.com/resources/images/AutoDL-log.png" width="400"/>
+</p>
+
 ---------
 [](LICENSE.md)
 
exps-tf/GDAS.py
@@ -1,5 +1,5 @@
 # CUDA_VISIBLE_DEVICES=0 python exps-tf/GDAS.py
-import os, sys, time, random, argparse
+import os, sys, math, time, random, argparse
 import tensorflow as tf
 from pathlib import Path
@@ -23,6 +23,24 @@ def pre_process(image_a, label_a, image_b, label_b):
   return standard_func(image_a), label_a, standard_func(image_b), label_b
 
 
+class CosineAnnealingLR(object):
+  def __init__(self, warmup_epochs, epochs, initial_lr, min_lr):
+    self.warmup_epochs = warmup_epochs
+    self.epochs = epochs
+    self.initial_lr = initial_lr
+    self.min_lr = min_lr
+
+  def get_lr(self, epoch):
+    # linear warm-up for the first warmup_epochs, then cosine decay towards min_lr
+    if epoch < self.warmup_epochs:
+      lr = self.min_lr + (epoch / self.warmup_epochs) * (self.initial_lr - self.min_lr)
+    elif epoch >= self.epochs:
+      lr = self.min_lr
+    else:
+      lr = self.min_lr + (self.initial_lr - self.min_lr) * 0.5 * (1 + math.cos(math.pi * epoch / self.epochs))
+    return lr
+
+
 def main(xargs):
   cifar10 = tf.keras.datasets.cifar10
 
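
With `warmup_epochs=0`, as instantiated in `main()` below, the new scheduler traces a plain cosine curve from `initial_lr` down towards `min_lr`. A minimal sanity check of the schedule, assuming the class above is in scope; the values mirror the argparse defaults, and the expected outputs follow directly from the formula:

```python
import math  # used inside CosineAnnealingLR.get_lr

scheduler = CosineAnnealingLR(warmup_epochs=0, epochs=250, initial_lr=0.025, min_lr=0.001)
for epoch in (0, 125, 249):
  print('epoch {:3d} -> lr {:.6f}'.format(epoch, scheduler.get_lr(epoch)))
# epoch   0 -> lr 0.025000   (cos(0) = 1, so lr = initial_lr)
# epoch 125 -> lr 0.013000   (cos(pi/2) = 0, so lr = (initial_lr + min_lr) / 2)
# epoch 249 -> lr 0.001001   (approaches min_lr as epoch -> epochs)
```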
@@ -50,13 +68,12 @@ def main(xargs):
                         'C'   : xargs.channel, 'N': xargs.num_cells, 'max_nodes': xargs.max_nodes,
                         'num_classes': 10, 'space': 'nas-bench-201', 'affine': True}, None)
   model = get_cell_based_tiny_net(config)
-  #import pdb; pdb.set_trace()
-  #model.build(((64, 32, 32, 3), (1,)))
-  #for x in model.trainable_variables:
-  #  print('{:30s} : {:}'.format(x.name, x.shape))
+  num_iters_per_epoch = int(tf.data.experimental.cardinality(search_ds).numpy())
+  #lr_scheduler = tf.keras.experimental.CosineDecay(xargs.w_lr_max, num_iters_per_epoch*xargs.epochs, xargs.w_lr_min / xargs.w_lr_max)
+  lr_scheduler = CosineAnnealingLR(0, xargs.epochs, xargs.w_lr_max, xargs.w_lr_min)
   # Choose optimizer
   loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
-  w_optimizer = SGDW(learning_rate=xargs.w_lr, weight_decay=xargs.w_weight_decay, momentum=xargs.w_momentum, nesterov=True)
+  w_optimizer = SGDW(learning_rate=xargs.w_lr_max, weight_decay=xargs.w_weight_decay, momentum=xargs.w_momentum, nesterov=True)
   a_optimizer = AdamW(learning_rate=xargs.arch_learning_rate, weight_decay=xargs.arch_weight_decay, beta_1=0.5, beta_2=0.999, epsilon=1e-07)
   #w_optimizer = tf.keras.optimizers.SGD(learning_rate=0.025, momentum=0.9, nesterov=True)
   #a_optimizer = tf.keras.optimizers.AdamW(learning_rate=xargs.arch_learning_rate, beta_1=0.5, beta_2=0.999, epsilon=1e-07)
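
`SGDW` and `AdamW` are imported elsewhere in this file and their definitions are not part of this diff. A sketch of equivalent optimizers, assuming tensorflow_addons' decoupled weight-decay implementations as stand-ins; the `weight_decay` value for `AdamW` is a placeholder, since the `--arch_weight_decay` default is not visible in this extract:

```python
import tensorflow_addons as tfa

# Assumed stand-ins for the SGDW/AdamW used above (decoupled weight decay).
w_optimizer = tfa.optimizers.SGDW(
    weight_decay=0.0005, learning_rate=0.025, momentum=0.9, nesterov=True)
a_optimizer = tfa.optimizers.AdamW(
    weight_decay=1e-3,  # placeholder: --arch_weight_decay default not shown here
    learning_rate=0.0003, beta_1=0.5, beta_2=0.999, epsilon=1e-07)
```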
@@ -99,7 +116,7 @@ def main(xargs):
     test_loss(t_loss)
     test_accuracy(labels, predictions)
 
-  print('{:} start searching with {:} epochs ({:} batches per epoch).'.format(time_string(), xargs.epochs, tf.data.experimental.cardinality(search_ds).numpy()))
+  print('{:} start searching with {:} epochs ({:} batches per epoch).'.format(time_string(), xargs.epochs, num_iters_per_epoch))
 
   for epoch in range(xargs.epochs):
     # Reset the metrics at the start of the next epoch
@@ -107,6 +124,8 @@ def main(xargs):
     test_loss.reset_states()  ; test_accuracy.reset_states()
     cur_tau = xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (xargs.epochs-1)
     tf_tau  = tf.cast(cur_tau, dtype=tf.float32, name='tau')
+    cur_lr  = lr_scheduler.get_lr(epoch)
+    tf.keras.backend.set_value(w_optimizer.lr, cur_lr)
 
     for trn_imgs, trn_labels, val_imgs, val_labels in search_ds:
       search_step(trn_imgs, trn_labels, val_imgs, val_labels, tf_tau)
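
Note the update pattern here: the Gumbel temperature `tau` anneals linearly from `tau_max` to `tau_min` over the epochs, while the learning rate is written directly into the optimizer's `lr` variable once per epoch, before the inner loop, so every gradient step within the epoch sees the new value. A small sketch of the mechanism, with a stock SGD optimizer standing in for `SGDW`:

```python
import tensorflow as tf

# Overwriting an optimizer's lr variable in place, as the loop above does.
opt = tf.keras.optimizers.SGD(learning_rate=0.025)      # stand-in for SGDW
tf.keras.backend.set_value(opt.lr, 0.013)
print(float(tf.keras.backend.get_value(opt.lr)))        # 0.013
```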
@@ -116,22 +135,26 @@ def main(xargs):
     #for test_images, test_labels in test_ds:
     #  test_step(test_images, test_labels)
 
-    template = '{:} Epoch {:03d}/{:03d}, Train-Loss: {:.3f}, Train-Accuracy: {:.2f}%, Valid-Loss: {:.3f}, Valid-Accuracy: {:.2f}% | tau={:.3f}'
+    cur_lr = float(tf.keras.backend.get_value(w_optimizer.lr))
+    template = '{:} Epoch {:03d}/{:03d}, Train-Loss: {:.3f}, Train-Accuracy: {:.2f}%, Valid-Loss: {:.3f}, Valid-Accuracy: {:.2f}% | tau={:.3f} | lr={:.6f}'
     print(template.format(time_string(), epoch+1, xargs.epochs,
                           train_loss.result(),
                           train_accuracy.result()*100,
                           valid_loss.result(),
                           valid_accuracy.result()*100,
-                          cur_tau))
+                          cur_tau,
+                          cur_lr))
     print('{:} genotype : {:}\n{:}\n'.format(time_string(), genotype, model.get_np_alphas()))
 
 
 if __name__ == '__main__':
   parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
   # training details
   parser.add_argument('--epochs'            , type=int  ,   default= 250  ,   help='')
   parser.add_argument('--tau_max'           , type=float,   default= 10   ,   help='')
   parser.add_argument('--tau_min'           , type=float,   default= 0.1  ,   help='')
-  parser.add_argument('--w_lr'              , type=float,   default= 0.025,   help='')
+  parser.add_argument('--w_lr_max'          , type=float,   default= 0.025,   help='')
+  parser.add_argument('--w_lr_min'          , type=float,   default= 0.001,   help='')
   parser.add_argument('--w_weight_decay'    , type=float,   default=0.0005,   help='')
   parser.add_argument('--w_momentum'        , type=float,   default= 0.9  ,   help='')
   parser.add_argument('--arch_learning_rate', type=float,   default=0.0003,   help='')

@@ -11,7 +11,7 @@ OPS = {
   'nor_conv_1x1': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 1, stride, affine),
   'nor_conv_3x3': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 3, stride, affine),
   'nor_conv_5x5': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 5, stride, affine),
-  'skip_connect': lambda C_in, C_out, stride, affine: Identity(C_in, C_out, stride)
+  'skip_connect': lambda C_in, C_out, stride, affine: Identity(C_in, C_out, stride) if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine)
 }
 
 NAS_BENCH_201         = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
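
The revised `skip_connect` entry keeps `Identity` only when the tensor shape is preserved (stride 1 and matching channel counts); any reduction or channel change now falls through to the new `FactorizedReduce` layer defined below. A quick dispatch check, assuming `OPS`, `Identity`, and `FactorizedReduce` are in scope:

```python
# stride 1 with matching channels -> Identity; anything else -> FactorizedReduce
same = OPS['skip_connect'](16, 16, 1, True)
down = OPS['skip_connect'](16, 32, 2, True)
print(type(same).__name__, type(down).__name__)   # Identity FactorizedReduce
```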
@@ -87,6 +87,38 @@ class ReLUConvBN(tf.keras.layers.Layer):
     return x
 
 
+class FactorizedReduce(tf.keras.layers.Layer):
+  def __init__(self, C_in, C_out, stride, affine):
+    super(FactorizedReduce, self).__init__()
+    assert C_out % 2 == 0, 'Need even number of filters when using this factorized reduction.'
+    self.stride = stride
+    self.relu   = tf.keras.activations.relu
+    if stride == 1:
+      self.layer = tf.keras.Sequential([
+                          tf.keras.layers.Conv2D(C_out, 1, stride, padding='same', use_bias=False),
+                          tf.keras.layers.BatchNormalization(center=affine, scale=affine)])
+    elif stride == 2:
+      # two parallel strided 1x1 convolutions, each producing C_out//2 channels (data_format == 'NHWC')
+      self.layer1 = tf.keras.layers.Conv2D(C_out//2, 1, stride, padding='same', use_bias=False)
+      self.layer2 = tf.keras.layers.Conv2D(C_out//2, 1, stride, padding='same', use_bias=False)
+      self.bn     = tf.keras.layers.BatchNormalization(center=affine, scale=affine)
+    else:
+      raise ValueError('invalid stride={:}'.format(stride))
+
+  def call(self, inputs, training):
+    x = self.relu(inputs)
+    if self.stride == 1:
+      return self.layer(x, training=training)
+    else:
+      path1 = x
+      # shift the second path by one pixel so the two convolutions sample complementary positions
+      path2 = tf.pad(x, [[0, 0], [0, 1], [0, 1], [0, 0]])[:, 1:, 1:, :]  # data_format == 'NHWC'
+      x1 = self.layer1(path1)
+      x2 = self.layer2(path2)
+      final_path = tf.concat(values=[x1, x2], axis=3)
+      return self.bn(final_path, training=training)
+
+
 class ResNetBasicblock(tf.keras.layers.Layer):
 
   def __init__(self, inplanes, planes, stride, affine=True):
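
A minimal shape check for the stride-2 path, assuming the class above is importable: the two offset 1x1 convolutions each halve the spatial resolution and contribute `C_out//2` channels, which are concatenated along the channel axis before batch normalization.

```python
import tensorflow as tf

# Stride-2 factorized reduction: 32x32x16 -> 16x16x32 (NHWC).
layer = FactorizedReduce(C_in=16, C_out=32, stride=2, affine=True)
x = tf.random.normal([4, 32, 32, 16])
y = layer(x, training=False)
print(y.shape)   # expected: (4, 16, 16, 32)
```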