update API
@@ -123,6 +123,13 @@ api.reload('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-BENCH-201-4-v1.0-arch
weights = api.get_net_param(3, 'cifar10', None) # Obtaining the weights of all trials for the 3-th architecture on cifar10. It will return a dict, where the key is the seed and the value is the trained weights.
```

To obtain the training and evaluation information (please see the comments [here](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_201_api/api.py#L172)):
```
api.get_more_info(112, 'cifar10', None, False, True)
api.get_more_info(112, 'ImageNet16-120', None, False, True) # the info of the last training epoch for the 112-th architecture (use the 200-epoch hyper-parameters and randomly select a trial)
```
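The call returns a dict of statistics. A rough sketch of the result is below; the keys follow the `xifo` dict assembled inside `get_more_info` (see the `api.py` changes further down), and the values are placeholders rather than real numbers:
```
info = api.get_more_info(112, 'cifar10', None, False, True)
# info is a dict, e.g.
# {'train-loss': ..., 'train-accuracy': ..., 'test-loss': ..., 'test-accuracy': ...}
# 'valid-loss' / 'valid-accuracy' appear when the queried dataset has a held-out validation split.
```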


## Instruction to Re-Generate NAS-Bench-201

@@ -1,206 +0,0 @@
# [D-X-Y]
# Run DARTS
# CUDA_VISIBLE_DEVICES=0 python exps-tf/one-shot-nas.py --epochs 50
#
import os, sys, math, time, random, argparse
import tensorflow as tf
from pathlib import Path

lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))

# self-lib
from tf_models import get_cell_based_tiny_net
from tf_optimizers import SGDW, AdamW
from config_utils import dict2config
from log_utils import time_string
from models import CellStructure


def pre_process(image_a, label_a, image_b, label_b):
  def standard_func(image):
    x = tf.pad(image, [[4, 4], [4, 4], [0, 0]])
    x = tf.image.random_crop(x, [32, 32, 3])
    x = tf.image.random_flip_left_right(x)
    return x
  return standard_func(image_a), label_a, standard_func(image_b), label_b


class CosineAnnealingLR(object):
  def __init__(self, warmup_epochs, epochs, initial_lr, min_lr):
    self.warmup_epochs = warmup_epochs
    self.epochs = epochs
    self.initial_lr = initial_lr
    self.min_lr = min_lr

  def get_lr(self, epoch):
    if epoch < self.warmup_epochs:
      lr = self.min_lr + (epoch/self.warmup_epochs) * (self.initial_lr-self.min_lr)
    elif epoch >= self.epochs:
      lr = self.min_lr
    else:
      lr = self.min_lr + (self.initial_lr-self.min_lr) * 0.5 * (1 + math.cos(math.pi * epoch / self.epochs))
    return lr



def main(xargs):
  cifar10 = tf.keras.datasets.cifar10

  (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  x_train, x_test = x_train / 255.0, x_test / 255.0
  x_train, x_test = x_train.astype('float32'), x_test.astype('float32')
  y_train, y_test = y_train.reshape(-1), y_test.reshape(-1)

  # split the CIFAR-10 training set into two halves: one for the network weights, one for the architecture parameters
  all_indexes = list(range(x_train.shape[0]))
  random.shuffle(all_indexes)
  s_train_idxs, s_valid_idxs = all_indexes[::2], all_indexes[1::2]
  search_train_x, search_train_y = x_train[s_train_idxs], y_train[s_train_idxs]
  search_valid_x, search_valid_y = x_train[s_valid_idxs], y_train[s_valid_idxs]
  #x_train, x_test = x_train[..., tf.newaxis], x_test[..., tf.newaxis]

  # Use tf.data
  #train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(64)
  search_ds = tf.data.Dataset.from_tensor_slices((search_train_x, search_train_y, search_valid_x, search_valid_y))
  search_ds = search_ds.map(pre_process).shuffle(1000).batch(64)

  test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

  # Create an instance of the model
  config = dict2config({'name': 'DARTS',
                        'C'   : xargs.channel, 'N': xargs.num_cells, 'max_nodes': xargs.max_nodes,
                        'num_classes': 10, 'space': 'nas-bench-201', 'affine': True}, None)
  model = get_cell_based_tiny_net(config)
  num_iters_per_epoch = int(tf.data.experimental.cardinality(search_ds).numpy())
  #lr_scheduler = tf.keras.experimental.CosineDecay(xargs.w_lr_max, num_iters_per_epoch*xargs.epochs, xargs.w_lr_min / xargs.w_lr_max)
  lr_scheduler = CosineAnnealingLR(0, xargs.epochs, xargs.w_lr_max, xargs.w_lr_min)
  # Choose optimizer
  loss_object = tf.keras.losses.CategoricalCrossentropy()
  w_optimizer = SGDW(learning_rate=xargs.w_lr_max, weight_decay=xargs.w_weight_decay, momentum=xargs.w_momentum, nesterov=True)
  a_optimizer = AdamW(learning_rate=xargs.arch_learning_rate, weight_decay=xargs.arch_weight_decay, beta_1=0.5, beta_2=0.999, epsilon=1e-07)
  #w_optimizer = tf.keras.optimizers.SGD(learning_rate=0.025, momentum=0.9, nesterov=True)
  #a_optimizer = tf.keras.optimizers.AdamW(learning_rate=xargs.arch_learning_rate, beta_1=0.5, beta_2=0.999, epsilon=1e-07)
  ####
  # metrics
  train_loss = tf.keras.metrics.Mean(name='train_loss')
  train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
  valid_loss = tf.keras.metrics.Mean(name='valid_loss')
  valid_accuracy = tf.keras.metrics.CategoricalAccuracy(name='valid_accuracy')
  test_loss = tf.keras.metrics.Mean(name='test_loss')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

  @tf.function
  def search_step(train_images, train_labels, valid_images, valid_labels):
    # optimize weights
    with tf.GradientTape() as tape:
      predictions = model(train_images, True)
      w_loss = loss_object(train_labels, predictions)
    net_w_param = model.get_weights()
    gradients = tape.gradient(w_loss, net_w_param)
    w_optimizer.apply_gradients(zip(gradients, net_w_param))
    train_loss(w_loss)
    train_accuracy(train_labels, predictions)
    # optimize alphas
    with tf.GradientTape() as tape:
      predictions = model(valid_images, True)
      a_loss = loss_object(valid_labels, predictions)
    net_a_param = model.get_alphas()
    gradients = tape.gradient(a_loss, net_a_param)
    a_optimizer.apply_gradients(zip(gradients, net_a_param))
    valid_loss(a_loss)
    valid_accuracy(valid_labels, predictions)

  # IFT with Neumann approximation
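  # Background for the step below: with training loss T(w, a) and validation loss V(w, a),
  # the implicit-function-theorem hypergradient w.r.t. the architecture parameters `a` is
  #   dV/da - (d^2 T / da dw) [d^2 T / dw^2]^{-1} (dV/dw),
  # where the inverse-Hessian-vector product is approximated by a truncated Neumann series:
  #   [d^2 T / dw^2]^{-1} v  ~  sum_{j>=0} (I - d^2 T / dw^2)^j v.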
  @tf.function
  def search_step_IFTNA(train_images, train_labels, valid_images, valid_labels, max_step):
    # optimize weights
    with tf.GradientTape() as tape:
      predictions = model(train_images, True)
      w_loss = loss_object(train_labels, predictions)
    # get the weights and architecture parameters
    net_w_param = model.get_weights()
    net_a_param = model.get_alphas()
    gradients = tape.gradient(w_loss, net_w_param)
    w_optimizer.apply_gradients(zip(gradients, net_w_param))
    train_loss(w_loss)
    train_accuracy(train_labels, predictions)
    # optimize alphas with the approximated hypergradient
    with tf.GradientTape(persistent=True) as tape:
      val_predictions = model(valid_images, True)
      val_loss = loss_object(valid_labels, val_predictions)
      trn_predictions = model(train_images, True)
      trn_loss = loss_object(train_labels, trn_predictions)
      # v1 / sum_p: truncated Neumann approximation of [d^2 T / dw^2]^{-1} (dV/dw)
      dV_dW = tape.gradient(val_loss, net_w_param)
      sum_p = v1 = dV_dW
      dT_dW = tape.gradient(trn_loss, net_w_param)
      for j in range(1, max_step):
        temp_dot = tape.gradient(dT_dW, net_w_param, output_gradients=v1)
        v1 = [tf.subtract(A, B) for A, B in zip(v1, temp_dot)]
        sum_p = [tf.add(A, B) for A, B in zip(sum_p, v1)]
      # v3: mixed second derivative (d^2 T / da dw) applied to sum_p
      v3 = tape.gradient(dT_dW, net_a_param, output_gradients=sum_p)
    dV_dl = tape.gradient(val_loss, net_a_param)
    # hypergradient = dV/da - v3; apply it to the architecture parameters
    a_gradients = [tf.subtract(A, B) for A, B in zip(dV_dl, v3)]
    a_optimizer.apply_gradients(zip(a_gradients, net_a_param))
    valid_loss(val_loss)
    valid_accuracy(valid_labels, val_predictions)

  # TEST
  @tf.function
  def test_step(images, labels):
    predictions = model(images)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)

  print('{:} start searching with {:} epochs ({:} batches per epoch).'.format(time_string(), xargs.epochs, num_iters_per_epoch))

  for epoch in range(xargs.epochs):
    # Reset the metrics at the start of the next epoch
    train_loss.reset_states() ; train_accuracy.reset_states()
    valid_loss.reset_states() ; valid_accuracy.reset_states()
    test_loss.reset_states()  ; test_accuracy.reset_states()
    cur_lr  = lr_scheduler.get_lr(epoch)
    tf.keras.backend.set_value(w_optimizer.lr, cur_lr)

    for trn_imgs, trn_labels, val_imgs, val_labels in search_ds:
      #search_step(trn_imgs, trn_labels, val_imgs, val_labels)
      trn_labels, val_labels = tf.one_hot(trn_labels, 10), tf.one_hot(val_labels, 10)
      search_step_IFTNA(trn_imgs, trn_labels, val_imgs, val_labels, 5)
    genotype = model.genotype()
    genotype = CellStructure(genotype)

    #for test_images, test_labels in test_ds:
    #  test_step(test_images, test_labels)

    cur_lr = float(tf.keras.backend.get_value(w_optimizer.lr))
    template = '{:} Epoch {:03d}/{:03d}, Train-Loss: {:.3f}, Train-Accuracy: {:.2f}%, Valid-Loss: {:.3f}, Valid-Accuracy: {:.2f}% | lr={:.6f}'
    print(template.format(time_string(), epoch+1, xargs.epochs,
                          train_loss.result(),
                          train_accuracy.result()*100,
                          valid_loss.result(),
                          valid_accuracy.result()*100,
                          cur_lr))
    print('{:} genotype : {:}\n{:}\n'.format(time_string(), genotype, model.get_np_alphas()))


if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  # training details
  parser.add_argument('--epochs'            , type=int  ,   default= 250  ,   help='')
  parser.add_argument('--w_lr_max'          , type=float,   default= 0.025,   help='')
  parser.add_argument('--w_lr_min'          , type=float,   default= 0.001,   help='')
  parser.add_argument('--w_weight_decay'    , type=float,   default=0.0005,   help='')
  parser.add_argument('--w_momentum'        , type=float,   default= 0.9  ,   help='')
  parser.add_argument('--arch_learning_rate', type=float,   default=0.0003,   help='')
  parser.add_argument('--arch_weight_decay' , type=float,   default=0.001,    help='')
  # macro structure
  parser.add_argument('--channel'           , type=int  ,   default=16,       help='')
  parser.add_argument('--num_cells'         , type=int  ,   default= 5,       help='')
  parser.add_argument('--max_nodes'         , type=int  ,   default= 4,       help='')
  args = parser.parse_args()
  main( args )
@@ -170,10 +170,28 @@ class NASBench201API(object):
    return archresult.get_comput_costs(dataset)

  # obtain the metric for the `index`-th architecture
  # `dataset` indicates the dataset:
  #   'cifar10-valid'  : using the proposed train set of CIFAR-10 as the training set
  #   'cifar10'        : using the proposed train+valid set of CIFAR-10 as the training set
  #   'cifar100'       : using the proposed train set of CIFAR-100 as the training set
  #   'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
  # `iepoch` indicates the index of training epochs from 0 to 11/199.
  #   When iepoch=None, it will return the metric for the last training epoch
  #   When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
  # `use_12epochs_result` indicates different hyper-parameters for training
  #   When use_12epochs_result=True, it trains the network with 12 epochs and the LR decayed from 0.1 to 0 within 12 epochs
  #   When use_12epochs_result=False, it trains the network with 200 epochs and the LR decayed from 0.1 to 0 within 200 epochs
  # `is_random`
  #   When is_random=True, the performance of a randomly selected trial will be returned
  #   When is_random=False, the performance of all trials will be averaged.
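  # For example (an illustrative call matching the signature below):
  #   api.get_more_info(112, 'cifar10', iepoch=None, use_12epochs_result=False, is_random=True)
  #   returns a dict with keys such as 'train-loss', 'train-accuracy', 'test-loss' and 'test-accuracy'.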
  def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
    if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
    else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
    archresult = arch2infos[index]
    # if we randomly select one trial, choose its seed first
    if isinstance(is_random, bool) and is_random:
      seeds = archresult.get_dataset_seeds(dataset)
      is_random = random.choice(seeds)
    if dataset == 'cifar10-valid':
      train_info = archresult.get_metrics(dataset, 'train'   , iepoch=iepoch, is_random=is_random)
      valid_info = archresult.get_metrics(dataset, 'x-valid' , iepoch=iepoch, is_random=is_random)
@@ -202,7 +220,7 @@ class NASBench201API(object):
        else:
          test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
      except:
        test__info = None
      try:
        valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
      except:
@@ -213,7 +231,7 @@ class NASBench201API(object):
        est_valid_info = None
      xifo = {'train-loss'    : train_info['loss'],
              'train-accuracy': train_info['accuracy']}
      if test__info is not None:
        xifo['test-loss'] = test__info['loss']
        xifo['test-accuracy'] = test__info['accuracy']
      if valid_info is not None:
@@ -347,14 +365,20 @@ class ArchResults(object):
        info = result.get_eval(setname, iepoch)
      for key, value in info.items(): infos[key].append( value )
    return_info = dict()
    if isinstance(is_random, bool) and is_random: # randomly select one
      index = random.randint(0, len(results)-1)
      for key, value in infos.items(): return_info[key] = value[index]
    elif isinstance(is_random, bool) and not is_random: # average
      for key, value in infos.items():
        if len(value) > 0 and value[0] is not None:
          return_info[key] = np.mean(value)
        else: return_info[key] = None
    elif isinstance(is_random, int): # specify the seed
      if is_random not in x_seeds: raise ValueError('can not find random seed ({:}) from {:}'.format(is_random, x_seeds))
      index = x_seeds.index(is_random)
      for key, value in infos.items(): return_info[key] = value[index]
    else:
      raise ValueError('invalid value for is_random: {:}'.format(is_random))
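    # e.g., is_random=777 returns the metrics of the single trial trained with seed 777
    # (assuming 777 appears in the seed list for this dataset); is_random=False averages over all trials.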
    return return_info

  def show(self, is_print=False):
@@ -363,6 +387,9 @@ class ArchResults(object):
  def get_dataset_names(self):
    return list(self.dataset_seed.keys())

  def get_dataset_seeds(self, dataset):
    return copy.deepcopy( self.dataset_seed[dataset] )

  def get_net_param(self, dataset, seed=None):
    if seed is None:
      x_seeds = self.dataset_seed[dataset]