update API
parent 533a508444
commit 3cd42e0ca1
@@ -123,6 +123,13 @@ api.reload('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-BENCH-201-4-v1.0-arch
weights = api.get_net_param(3, 'cifar10', None) # Obtain the weights of all trials for the 3rd architecture on cifar10. It returns a dict, where the key is the seed and the value is the trained weights.
```
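For illustration, here is a minimal sketch (not part of the repository) of iterating over the dict returned above, assuming `api` has been created and the archive reloaded as shown earlier and that each value behaves like a PyTorch state_dict:
```
seed_to_weights = api.get_net_param(3, 'cifar10', None)
for seed, params in seed_to_weights.items():
  # `params` holds the trained weights of one trial (one random seed) of the 3rd architecture
  print('seed {:} -> {:} parameter tensors'.format(seed, len(params)))
```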
To obtain the training and evaluation information (please see the comments [here](https://github.com/D-X-Y/AutoDL-Projects/blob/master/lib/nas_201_api/api.py#L172)):
```
api.get_more_info(112, 'cifar10', None, False, True)
api.get_more_info(112, 'ImageNet16-120', None, False, True) # the info of the last training epoch for the 112-th architecture (use the 200-epoch hyper-parameters and randomly select a trial)
```
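As a rough sketch (not from the repository), the returned value is a dict whose keys follow the comments in `nas_201_api/api.py` shown below, e.g. `train-loss`, `train-accuracy`, and `test-loss`/`test-accuracy` when available:
```
info = api.get_more_info(112, 'cifar10', None, False, True)
print('train loss/acc : {:} / {:}'.format(info['train-loss'], info['train-accuracy']))
if 'test-accuracy' in info:
  print('test  loss/acc : {:} / {:}'.format(info.get('test-loss'), info['test-accuracy']))
```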
## Instruction to Re-Generate NAS-Bench-201
@@ -1,206 +0,0 @@
# [D-X-Y]
# Run DARTS
# CUDA_VISIBLE_DEVICES=0 python exps-tf/one-shot-nas.py --epochs 50
#
import os, sys, math, time, random, argparse
import tensorflow as tf
from pathlib import Path

lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))

# self-lib
from tf_models import get_cell_based_tiny_net
from tf_optimizers import SGDW, AdamW
from config_utils import dict2config
from log_utils import time_string
from models import CellStructure


def pre_process(image_a, label_a, image_b, label_b):
  def standard_func(image):
    x = tf.pad(image, [[4, 4], [4, 4], [0, 0]])
    x = tf.image.random_crop(x, [32, 32, 3])
    x = tf.image.random_flip_left_right(x)
    return x
  return standard_func(image_a), label_a, standard_func(image_b), label_b


class CosineAnnealingLR(object):
  def __init__(self, warmup_epochs, epochs, initial_lr, min_lr):
    self.warmup_epochs = warmup_epochs
    self.epochs        = epochs
    self.initial_lr    = initial_lr
    self.min_lr        = min_lr

  def get_lr(self, epoch):
    if epoch < self.warmup_epochs:
      lr = self.min_lr + (epoch/self.warmup_epochs) * (self.initial_lr-self.min_lr)
    elif epoch >= self.epochs:
      lr = self.min_lr
    else:
      lr = self.min_lr + (self.initial_lr-self.min_lr) * 0.5 * (1 + math.cos(math.pi * epoch / self.epochs))
    return lr


def main(xargs):
  cifar10 = tf.keras.datasets.cifar10

  (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  x_train, x_test = x_train / 255.0, x_test / 255.0
  x_train, x_test = x_train.astype('float32'), x_test.astype('float32')
  y_train, y_test = y_train.reshape(-1), y_test.reshape(-1)

  # Add a channels dimension
  all_indexes = list(range(x_train.shape[0]))
  random.shuffle(all_indexes)
  s_train_idxs, s_valid_idxs = all_indexes[::2], all_indexes[1::2]
  search_train_x, search_train_y = x_train[s_train_idxs], y_train[s_train_idxs]
  search_valid_x, search_valid_y = x_train[s_valid_idxs], y_train[s_valid_idxs]
  #x_train, x_test = x_train[..., tf.newaxis], x_test[..., tf.newaxis]

  # Use tf.data
  #train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(64)
  search_ds = tf.data.Dataset.from_tensor_slices((search_train_x, search_train_y, search_valid_x, search_valid_y))
  search_ds = search_ds.map(pre_process).shuffle(1000).batch(64)

  test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

  # Create an instance of the model
  config = dict2config({'name': 'DARTS',
                        'C' : xargs.channel, 'N': xargs.num_cells, 'max_nodes': xargs.max_nodes,
                        'num_classes': 10, 'space': 'nas-bench-201', 'affine': True}, None)
  model = get_cell_based_tiny_net(config)
  num_iters_per_epoch = int(tf.data.experimental.cardinality(search_ds).numpy())
  #lr_schedular = tf.keras.experimental.CosineDecay(xargs.w_lr_max, num_iters_per_epoch*xargs.epochs, xargs.w_lr_min / xargs.w_lr_max)
  lr_schedular = CosineAnnealingLR(0, xargs.epochs, xargs.w_lr_max, xargs.w_lr_min)
  # Choose optimizer
  loss_object = tf.keras.losses.CategoricalCrossentropy()
  w_optimizer = SGDW(learning_rate=xargs.w_lr_max, weight_decay=xargs.w_weight_decay, momentum=xargs.w_momentum, nesterov=True)
  a_optimizer = AdamW(learning_rate=xargs.arch_learning_rate, weight_decay=xargs.arch_weight_decay, beta_1=0.5, beta_2=0.999, epsilon=1e-07)
  #w_optimizer = tf.keras.optimizers.SGD(learning_rate=0.025, momentum=0.9, nesterov=True)
  #a_optimizer = tf.keras.optimizers.AdamW(learning_rate=xargs.arch_learning_rate, beta_1=0.5, beta_2=0.999, epsilon=1e-07)
  ####
  # metrics
  train_loss     = tf.keras.metrics.Mean(name='train_loss')
  train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
  valid_loss     = tf.keras.metrics.Mean(name='valid_loss')
  valid_accuracy = tf.keras.metrics.CategoricalAccuracy(name='valid_accuracy')
  test_loss      = tf.keras.metrics.Mean(name='test_loss')
  test_accuracy  = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

  @tf.function
  def search_step(train_images, train_labels, valid_images, valid_labels):
    # optimize weights
    with tf.GradientTape() as tape:
      predictions = model(train_images, True)
      w_loss = loss_object(train_labels, predictions)
    net_w_param = model.get_weights()
    gradients = tape.gradient(w_loss, net_w_param)
    w_optimizer.apply_gradients(zip(gradients, net_w_param))
    train_loss(w_loss)
    train_accuracy(train_labels, predictions)
    # optimize alphas
    with tf.GradientTape() as tape:
      predictions = model(valid_images, True)
      a_loss = loss_object(valid_labels, predictions)
    net_a_param = model.get_alphas()
    gradients = tape.gradient(a_loss, net_a_param)
    a_optimizer.apply_gradients(zip(gradients, net_a_param))
    valid_loss(a_loss)
    valid_accuracy(valid_labels, predictions)

  # IFT with Neumann approximation
  @tf.function
  def search_step_IFTNA(train_images, train_labels, valid_images, valid_labels, max_step):
    # optimize weights
    with tf.GradientTape() as tape:
      predictions = model(train_images, True)
      w_loss = loss_object(train_labels, predictions)
    # get the weights
    net_w_param = model.get_weights()
    net_a_param = model.get_alphas()
    gradients = tape.gradient(w_loss, net_w_param)
    w_optimizer.apply_gradients(zip(gradients, net_w_param))
    train_loss(w_loss)
    train_accuracy(train_labels, predictions)
    # optimize alphas
    with tf.GradientTape(persistent=True) as tape:
      predictions = model(valid_images, True)
      val_loss = loss_object(valid_labels, predictions)
      predictions = model(train_images, True)
      trn_loss = loss_object(train_labels, predictions)
    # ----
    dV_dW = tape.gradient(val_loss, net_w_param)
    # approxInverseHVP to calculate v2
    sum_p = v1 = dV_dW
    dT_dW = tape.gradient(trn_loss, net_w_param)
    for j in range(1, max_step):
      temp_dot = tape.gradient(dT_dW, net_w_param, output_gradients=v1)
      v1    = [tf.subtract(A, B) for A, B in zip(v1, temp_dot)]
      sum_p = [tf.add(A, B) for A, B in zip(sum_p, v1)]
    # calculate v3
    dT_dl = tape.gradient(trn_loss, net_a_param)
    import pdb; pdb.set_trace()
    v3 = tape.gradient(dT_dl, net_w_param, output_gradients=sum_p)
    dV_dl = tape.gradient(val_loss, net_a_param)
    a_gradients = [tf.subtract(A, B) for A, B in zip(dV_dl, v3)]
    import pdb; pdb.set_trace()
    print('--')

  # TEST
  @tf.function
  def test_step(images, labels):
    predictions = model(images)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)

  print('{:} start searching with {:} epochs ({:} batches per epoch).'.format(time_string(), xargs.epochs, num_iters_per_epoch))

  for epoch in range(xargs.epochs):
    # Reset the metrics at the start of the next epoch
    train_loss.reset_states() ; train_accuracy.reset_states()
    test_loss.reset_states()  ; test_accuracy.reset_states()
    cur_lr = lr_schedular.get_lr(epoch)
    tf.keras.backend.set_value(w_optimizer.lr, cur_lr)

    for trn_imgs, trn_labels, val_imgs, val_labels in search_ds:
      #search_step(trn_imgs, trn_labels, val_imgs, val_labels)
      trn_labels, val_labels = tf.one_hot(trn_labels, 10), tf.one_hot(val_labels, 10)
      search_step_IFTNA(trn_imgs, trn_labels, val_imgs, val_labels, 5)
    genotype = model.genotype()
    genotype = CellStructure(genotype)

    #for test_images, test_labels in test_ds:
    #  test_step(test_images, test_labels)

    cur_lr = float(tf.keras.backend.get_value(w_optimizer.lr))
    template = '{:} Epoch {:03d}/{:03d}, Train-Loss: {:.3f}, Train-Accuracy: {:.2f}%, Valid-Loss: {:.3f}, Valid-Accuracy: {:.2f}% | lr={:.6f}'
    print(template.format(time_string(), epoch+1, xargs.epochs,
                          train_loss.result(),
                          train_accuracy.result()*100,
                          valid_loss.result(),
                          valid_accuracy.result()*100,
                          cur_lr))
    print('{:} genotype : {:}\n{:}\n'.format(time_string(), genotype, model.get_np_alphas()))


if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='NAS-Bench-201', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  # training details
  parser.add_argument('--epochs'            , type=int  , default=250   , help='')
  parser.add_argument('--w_lr_max'          , type=float, default=0.025 , help='')
  parser.add_argument('--w_lr_min'          , type=float, default=0.001 , help='')
  parser.add_argument('--w_weight_decay'    , type=float, default=0.0005, help='')
  parser.add_argument('--w_momentum'        , type=float, default=0.9   , help='')
  parser.add_argument('--arch_learning_rate', type=float, default=0.0003, help='')
  parser.add_argument('--arch_weight_decay' , type=float, default=0.001 , help='')
  # macro structure
  parser.add_argument('--channel'           , type=int  , default=16, help='')
  parser.add_argument('--num_cells'         , type=int  , default=5 , help='')
  parser.add_argument('--max_nodes'         , type=int  , default=4 , help='')
  args = parser.parse_args()
  main( args )
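For context, `search_step_IFTNA` above follows the implicit-function-theorem (IFT) hypergradient with a truncated Neumann series: it accumulates `sum_p = sum_{j=0..K} (I - H)^j dV_dW` as an approximation of `H^{-1} dV_dW`, where `H` is the Hessian of the training loss w.r.t. the weights. A small self-contained NumPy sketch of that identity (illustrative only, not part of this commit) is:
```
import numpy as np

def neumann_inverse_hvp(hess, vec, num_terms):
  # Approximate hess^{-1} @ vec with the truncated Neumann series
  #   hess^{-1} vec ~= sum_{j=0}^{K} (I - hess)^j vec, valid when the spectral radius of (I - hess) is below 1.
  # This mirrors the loop in search_step_IFTNA: v1 <- v1 - hess @ v1 ; sum_p <- sum_p + v1.
  p, total = vec.copy(), vec.copy()
  for _ in range(num_terms):
    p = p - hess @ p
    total = total + p
  return total

rng  = np.random.default_rng(0)
A    = 0.05 * rng.standard_normal((5, 5))
hess = np.eye(5) * 0.9 + (A + A.T)   # symmetric, eigenvalues close to 0.9
vec  = rng.standard_normal(5)
approx = neumann_inverse_hvp(hess, vec, num_terms=60)
print(np.allclose(approx, np.linalg.solve(hess, vec), atol=1e-6))  # True
```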
@@ -170,10 +170,28 @@ class NASBench201API(object):
    return archresult.get_comput_costs(dataset)

  # obtain the metric for the `index`-th architecture
  # `dataset` indicates the dataset:
  #   'cifar10-valid'  : using the proposed train set of CIFAR-10 as the training set
  #   'cifar10'        : using the proposed train+valid set of CIFAR-10 as the training set
  #   'cifar100'       : using the proposed train set of CIFAR-100 as the training set
  #   'ImageNet16-120' : using the proposed train set of ImageNet-16-120 as the training set
  # `iepoch` indicates the index of training epochs from 0 to 11/199.
  #   When iepoch=None, it will return the metric for the last training epoch
  #   When iepoch=11, it will return the metric for the 11-th training epoch (starting from 0)
  # `use_12epochs_result` indicates different hyper-parameters for training
  #   When use_12epochs_result=True, it trains the network with 12 epochs and the LR decays from 0.1 to 0 within 12 epochs
  #   When use_12epochs_result=False, it trains the network with 200 epochs and the LR decays from 0.1 to 0 within 200 epochs
  # `is_random`
  #   When is_random=True, the performance of a randomly selected trial will be returned
  #   When is_random=False, the performance of all trials will be averaged.
  def get_more_info(self, index, dataset, iepoch=None, use_12epochs_result=False, is_random=True):
    if use_12epochs_result: basestr, arch2infos = '12epochs' , self.arch2infos_less
    else                  : basestr, arch2infos = '200epochs', self.arch2infos_full
    archresult = arch2infos[index]
    # if one trial is randomly selected, pick its seed first
    if isinstance(is_random, bool) and is_random:
      seeds = archresult.get_dataset_seeds(dataset)
      is_random = random.choice(seeds)
    if dataset == 'cifar10-valid':
      train_info = archresult.get_metrics(dataset, 'train'  , iepoch=iepoch, is_random=is_random)
      valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
@@ -202,7 +220,7 @@ class NASBench201API(object):
        else:
          test__info = archresult.get_metrics(dataset, 'x-test', iepoch=iepoch, is_random=is_random)
      except:
        test__info = None
      try:
        valid_info = archresult.get_metrics(dataset, 'x-valid', iepoch=iepoch, is_random=is_random)
      except:
@@ -213,7 +231,7 @@ class NASBench201API(object):
      est_valid_info = None
    xifo = {'train-loss'    : train_info['loss'],
            'train-accuracy': train_info['accuracy']}
    if test__info is not None:
      xifo['test-loss']     = test__info['loss']
      xifo['test-accuracy'] = test__info['accuracy']
    if valid_info is not None:
@@ -347,14 +365,20 @@ class ArchResults(object):
      info = result.get_eval(setname, iepoch)
      for key, value in info.items(): infos[key].append( value )
    return_info = dict()
    if isinstance(is_random, bool) and is_random: # randomly select one
      index = random.randint(0, len(results)-1)
      for key, value in infos.items(): return_info[key] = value[index]
    elif isinstance(is_random, bool) and not is_random: # average
      for key, value in infos.items():
        if len(value) > 0 and value[0] is not None:
          return_info[key] = np.mean(value)
        else: return_info[key] = None
    elif isinstance(is_random, int): # specify the seed
      if is_random not in x_seeds: raise ValueError('can not find random seed ({:}) from {:}'.format(is_random, x_seeds))
      index = x_seeds.index(is_random)
      for key, value in infos.items(): return_info[key] = value[index]
    else:
      raise ValueError('invalid value for is_random: {:}'.format(is_random))
    return return_info

  def show(self, is_print=False):
@@ -363,6 +387,9 @@ class ArchResults(object):
  def get_dataset_names(self):
    return list(self.dataset_seed.keys())

  def get_dataset_seeds(self, dataset):
    return copy.deepcopy( self.dataset_seed[dataset] )

  def get_net_param(self, dataset, seed=None):
    if seed is None:
      x_seeds = self.dataset_seed[dataset]