update TF models (beta version)

parent e6ca3628ce
commit 5ac5060a33

configs/archs/CIFAR-SIM05.config (new file, +7)
@@ -0,0 +1,7 @@
{
  "dataset"            : ["str", "cifar"],
  "arch"               : ["str", "simres"],
  "depth"              : ["int", 5],
  "super_type"         : ["str" , "basic"],
  "zero_init_residual" : ["bool", "0"]
}
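Editor's note: each entry in this config is a ["type", value] pair that the repo's config loader (dict2config / config_utils) presumably casts into a typed attribute. A minimal, hypothetical sketch of that decoding; the helper name parse_typed_pairs is ours, not the repo's:

# Hypothetical sketch: decode ["type", value] pairs like the config above.
import json

_CASTS = {'str': str, 'int': int, 'float': float, 'bool': lambda v: bool(int(v))}

def parse_typed_pairs(path):
  with open(path) as f:
    raw = json.load(f)
  # e.g. "depth": ["int", 5] -> {"depth": 5}
  return {key: _CASTS[typ](val) for key, (typ, val) in raw.items()}

# parse_typed_pairs('configs/archs/CIFAR-SIM05.config')
# -> {'dataset': 'cifar', 'arch': 'simres', 'depth': 5, 'super_type': 'basic', 'zero_init_residual': False}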
exps-tf/GDAS.py (new file, +144)
@@ -0,0 +1,144 @@
# CUDA_VISIBLE_DEVICES=0 python exps-tf/GDAS.py
import os, sys, time, random, argparse
import tensorflow as tf
from pathlib import Path

lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))

# self-lib
from tf_models import get_cell_based_tiny_net
from tf_optimizers import SGDW, AdamW
from config_utils import dict2config
from log_utils import time_string
from models import CellStructure


def pre_process(image_a, label_a, image_b, label_b):
  def standard_func(image):
    x = tf.pad(image, [[4, 4], [4, 4], [0, 0]])
    x = tf.image.random_crop(x, [32, 32, 3])
    x = tf.image.random_flip_left_right(x)
    return x
  return standard_func(image_a), label_a, standard_func(image_b), label_b


def main(xargs):
  cifar10 = tf.keras.datasets.cifar10

  (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  x_train, x_test = x_train / 255.0, x_test / 255.0
  x_train, x_test = x_train.astype('float32'), x_test.astype('float32')

  # Add a channels dimension
  all_indexes = list(range(x_train.shape[0]))
  random.shuffle(all_indexes)
  s_train_idxs, s_valid_idxs = all_indexes[::2], all_indexes[1::2]
  search_train_x, search_train_y = x_train[s_train_idxs], y_train[s_train_idxs]
  search_valid_x, search_valid_y = x_train[s_valid_idxs], y_train[s_valid_idxs]
  #x_train, x_test = x_train[..., tf.newaxis], x_test[..., tf.newaxis]

  # Use tf.data
  #train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(64)
  search_ds = tf.data.Dataset.from_tensor_slices((search_train_x, search_train_y, search_valid_x, search_valid_y))
  search_ds = search_ds.map(pre_process).shuffle(1000).batch(64)

  test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

  # Create an instance of the model
  config = dict2config({'name': 'GDAS',
                        'C' : xargs.channel, 'N': xargs.num_cells, 'max_nodes': xargs.max_nodes,
                        'num_classes': 10, 'space': 'nas-bench-102', 'affine': True}, None)
  model = get_cell_based_tiny_net(config)
  #import pdb; pdb.set_trace()
  #model.build(((64, 32, 32, 3), (1,)))
  #for x in model.trainable_variables:
  #  print('{:30s} : {:}'.format(x.name, x.shape))
  # Choose optimizer
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
  w_optimizer = SGDW(learning_rate=xargs.w_lr, weight_decay=xargs.w_weight_decay, momentum=xargs.w_momentum, nesterov=True)
  a_optimizer = AdamW(learning_rate=xargs.arch_learning_rate, weight_decay=xargs.arch_weight_decay, beta_1=0.5, beta_2=0.999, epsilon=1e-07)
  #w_optimizer = tf.keras.optimizers.SGD(learning_rate=0.025, momentum=0.9, nesterov=True)
  #a_optimizer = tf.keras.optimizers.AdamW(learning_rate=xargs.arch_learning_rate, beta_1=0.5, beta_2=0.999, epsilon=1e-07)
  ####
  # metrics
  train_loss = tf.keras.metrics.Mean(name='train_loss')
  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
  valid_loss = tf.keras.metrics.Mean(name='valid_loss')
  valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='valid_accuracy')
  test_loss = tf.keras.metrics.Mean(name='test_loss')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

  @tf.function
  def search_step(train_images, train_labels, valid_images, valid_labels, tf_tau):
    # optimize weights
    with tf.GradientTape() as tape:
      predictions = model(train_images, tf_tau, True)
      w_loss = loss_object(train_labels, predictions)
    net_w_param = model.get_weights()
    gradients = tape.gradient(w_loss, net_w_param)
    w_optimizer.apply_gradients(zip(gradients, net_w_param))
    train_loss(w_loss)
    train_accuracy(train_labels, predictions)
    # optimize alphas
    with tf.GradientTape() as tape:
      predictions = model(valid_images, tf_tau, True)
      a_loss = loss_object(valid_labels, predictions)
    net_a_param = model.get_alphas()
    gradients = tape.gradient(a_loss, net_a_param)
    a_optimizer.apply_gradients(zip(gradients, net_a_param))
    valid_loss(a_loss)
    valid_accuracy(valid_labels, predictions)

  # TEST
  @tf.function
  def test_step(images, labels):
    predictions = model(images)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)

  print('{:} start searching with {:} epochs ({:} batches per epoch).'.format(time_string(), xargs.epochs, tf.data.experimental.cardinality(search_ds).numpy()))

  for epoch in range(xargs.epochs):
    # Reset the metrics at the start of the next epoch
    train_loss.reset_states() ; train_accuracy.reset_states()
    test_loss.reset_states()  ; test_accuracy.reset_states()
    cur_tau = xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (xargs.epochs-1)
    tf_tau  = tf.cast(cur_tau, dtype=tf.float32, name='tau')

    for trn_imgs, trn_labels, val_imgs, val_labels in search_ds:
      search_step(trn_imgs, trn_labels, val_imgs, val_labels, tf_tau)
    genotype = model.genotype()
    genotype = CellStructure(genotype)

    #for test_images, test_labels in test_ds:
    #  test_step(test_images, test_labels)

    template = '{:} Epoch {:03d}/{:03d}, Train-Loss: {:.3f}, Train-Accuracy: {:.2f}%, Valid-Loss: {:.3f}, Valid-Accuracy: {:.2f}% | tau={:.3f}'
    print(template.format(time_string(), epoch+1, xargs.epochs,
                          train_loss.result(),
                          train_accuracy.result()*100,
                          valid_loss.result(),
                          valid_accuracy.result()*100,
                          cur_tau))
    print('{:} genotype : {:}\n{:}\n'.format(time_string(), genotype, model.get_np_alphas()))


if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  # training details
  parser.add_argument('--epochs'             , type=int  , default= 250  , help='')
  parser.add_argument('--tau_max'            , type=float, default= 10   , help='')
  parser.add_argument('--tau_min'            , type=float, default= 0.1  , help='')
  parser.add_argument('--w_lr'               , type=float, default= 0.025, help='')
  parser.add_argument('--w_weight_decay'     , type=float, default=0.0005, help='')
  parser.add_argument('--w_momentum'         , type=float, default= 0.9  , help='')
  parser.add_argument('--arch_learning_rate' , type=float, default=0.0003, help='')
  parser.add_argument('--arch_weight_decay'  , type=float, default=0.001 , help='')
  # marco structure
  parser.add_argument('--channel'            , type=int  , default=16, help='')
  parser.add_argument('--num_cells'          , type=int  , default= 5, help='')
  parser.add_argument('--max_nodes'          , type=int  , default= 4, help='')
  args = parser.parse_args()
  main( args )
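Editor's note: the search loop above anneals the Gumbel temperature linearly from --tau_max down to --tau_min across the epochs. A standalone sketch of that schedule, restating the cur_tau expression used in main() (function name is ours):

# Standalone sketch of the linear tau schedule used in main() above.
def linear_tau(epoch, epochs, tau_max=10.0, tau_min=0.1):
  # epoch 0 -> tau_max, last epoch -> tau_min
  return tau_max - (tau_max - tau_min) * epoch / (epochs - 1)

# e.g. with the default 250 epochs: linear_tau(0, 250) == 10.0, linear_tau(249, 250) == 0.1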
@@ -6,7 +6,6 @@ import numpy as np
 from collections import OrderedDict
 lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
 if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
-from graphviz import Digraph


 def test_nas_api():
@@ -29,6 +28,7 @@ OPS = ['skip-connect', 'conv-1x1', 'conv-3x3', 'pool-3x3']
 COLORS = ['chartreuse' , 'cyan' , 'navyblue', 'chocolate1']

 def plot(filename):
+  from graphviz import Digraph
   g = Digraph(
       format='png',
       edge_attr=dict(fontsize='20', fontname="times"),
@@ -53,6 +53,26 @@ def plot(filename):
   g.render(filename, cleanup=True, view=False)


+def test_auto_grad():
+  class Net(torch.nn.Module):
+    def __init__(self, iS):
+      super(Net, self).__init__()
+      self.layer = torch.nn.Linear(iS, 1)
+    def forward(self, inputs):
+      outputs = self.layer(inputs)
+      outputs = torch.exp(outputs)
+      return outputs.mean()
+  net = Net(10)
+  inputs = torch.rand(256, 10)
+  loss = net(inputs)
+  first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True)
+  first_order_grads = torch.cat([x.view(-1) for x in first_order_grads])
+  second_order_grads = []
+  for grads in first_order_grads:
+    s_grads = torch.autograd.grad(grads, net.parameters())
+    second_order_grads.append( s_grads )
+
 if __name__ == '__main__':
-  test_nas_api()
-  for i in range(200): plot('{:04d}'.format(i))
+  #test_nas_api()
+  #for i in range(200): plot('{:04d}'.format(i))
+  test_auto_grad()
@@ -1,7 +1,8 @@
 ##################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 ##################################################
-from .logger import Logger
-from .print_logger import PrintLogger
+# every package does not rely on pytorch or tensorflow
+# I tried to list all dependency here: os, sys, time, numpy, (possibly) matplotlib
+from .logger import Logger, PrintLogger
 from .meter import AverageMeter
-from .time_utils import time_for_file, time_string, time_string_short, time_print, convert_size2str, convert_secs2time
+from .time_utils import time_for_file, time_string, time_string_short, time_print, convert_secs2time
@@ -1,9 +1,6 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-#
+##################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
+##################################################
 from pathlib import Path
 import importlib, warnings
 import os, sys, time, numpy as np
@@ -16,6 +13,19 @@ if importlib.util.find_spec('tensorflow'):
   import tensorflow as tf


+class PrintLogger(object):
+
+  def __init__(self):
+    """Create a summary writer logging to log_dir."""
+    self.name = 'PrintLogger'
+
+  def log(self, string):
+    print (string)
+
+  def close(self):
+    print ('-'*30 + ' close printer ' + '-'*30)
+
+
 class Logger(object):

   def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False):
@@ -1,4 +1,3 @@
-import time, sys
 import numpy as np


@@ -1,14 +0,0 @@
-import os, sys, time
-
-
-class PrintLogger(object):
-
-  def __init__(self):
-    """Create a summary writer logging to log_dir."""
-    self.name = 'PrintLogger'
-
-  def log(self, string):
-    print (string)
-
-  def close(self):
-    print ('-'*30 + ' close printer ' + '-'*30)
@@ -1,37 +1,27 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-#
+##################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
+##################################################
 import time, sys
 import numpy as np

 def time_for_file():
   ISOTIMEFORMAT='%d-%h-at-%H-%M-%S'
-  return '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
+  return '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))

 def time_string():
   ISOTIMEFORMAT='%Y-%m-%d %X'
-  string = '[{}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
+  string = '[{:}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
   return string

 def time_string_short():
   ISOTIMEFORMAT='%Y%m%d'
-  string = '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
+  string = '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
   return string

 def time_print(string, is_print=True):
   if (is_print):
     print('{} : {}'.format(time_string(), string))

-def convert_size2str(torch_size):
-  dims = len(torch_size)
-  string = '['
-  for idim in range(dims):
-    string = string + ' {}'.format(torch_size[idim])
-  return string + ']'
-
 def convert_secs2time(epoch_time, return_str=False):
   need_hour = int(epoch_time / 3600)
   need_mins = int((epoch_time - 3600*need_hour) / 60)
@@ -1,7 +1,6 @@
 ##################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
 ##################################################
-import torch
 from os import path as osp

 __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \
@@ -126,6 +125,11 @@ def obtain_search_model(config):
       elif config.search_mode == 'shape':
         return SearchShapeCifarResNet(config.module, config.depth, config.class_num)
       else: raise ValueError('invalid search mode : {:}'.format(config.search_mode))
+    elif config.arch == 'simres':
+      from .shape_searchs import SearchWidthSimResNet
+      if config.search_mode == 'width':
+        return SearchWidthSimResNet(config.depth, config.class_num)
+      else: raise ValueError('invalid search mode : {:}'.format(config.search_mode))
     else:
       raise ValueError('invalid arch : {:} for dataset [{:}]'.format(config.arch, config.dataset))
   elif config.dataset == 'imagenet':
@@ -140,6 +144,7 @@ def obtain_search_model(config):


 def load_net_from_checkpoint(checkpoint):
+  import torch
   assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint)
   checkpoint = torch.load(checkpoint)
   model_config = dict2config(checkpoint['model-config'], None)
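Editor's note: a hypothetical sketch (not part of the commit) of a config that reaches the new 'simres' branch above. The field names ('arch', 'search_mode', 'depth', 'class_num') follow the code; the exact import path for obtain_search_model is an assumption:

# Hypothetical sketch: routing a 'simres' / 'width' config to SearchWidthSimResNet.
from config_utils import dict2config
from models import obtain_search_model   # assumed export of the module patched above

config = dict2config({'dataset': 'cifar', 'arch': 'simres', 'search_mode': 'width',
                      'depth': 5, 'class_num': 10}, None)
net = obtain_search_model(config)        # -> SearchWidthSimResNet(depth=5, num_classes=10)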
lib/models/shape_searchs/SearchSimResNet_width.py (new file, +316)
@@ -0,0 +1,316 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices as get_choices


def conv_forward(inputs, conv, choices):
  iC = conv.in_channels
  fill_size = list(inputs.size())
  fill_size[1] = iC - fill_size[1]
  filled = torch.zeros(fill_size, device=inputs.device)
  xinputs = torch.cat((inputs, filled), dim=1)
  outputs = conv(xinputs)
  selecteds = [outputs[:,:oC] for oC in choices]
  return selecteds


class ConvBNReLU(nn.Module):
  num_conv = 1
  def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
    super(ConvBNReLU, self).__init__()
    self.InShape = None
    self.OutShape = None
    self.choices = get_choices(nOut)
    self.register_buffer('choices_tensor', torch.Tensor( self.choices ))

    if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
    else       : self.avg = None
    self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
    #if has_bn : self.bn = nn.BatchNorm2d(nOut)
    #else      : self.bn = None
    self.has_bn = has_bn
    self.BNs = nn.ModuleList()
    for i, _out in enumerate(self.choices):
      self.BNs.append(nn.BatchNorm2d(_out))
    if has_relu: self.relu = nn.ReLU(inplace=True)
    else       : self.relu = None
    self.in_dim = nIn
    self.out_dim = nOut
    self.search_mode = 'basic'

  def get_flops(self, channels, check_range=True, divide=1):
    iC, oC = channels
    if check_range: assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, '{:} vs {:} | {:} vs {:}'.format(iC, self.conv.in_channels, oC, self.conv.out_channels)
    assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, 'invalid in-shape : {:}'.format(self.InShape)
    assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, 'invalid out-shape : {:}'.format(self.OutShape)
    #conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
    conv_per_position_flops = (self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups)
    all_positions = self.OutShape[0] * self.OutShape[1]
    flops = (conv_per_position_flops * all_positions / divide) * iC * oC
    if self.conv.bias is not None: flops += all_positions / divide
    return flops

  def get_range(self):
    return [self.choices]

  def forward(self, inputs):
    if self.search_mode == 'basic':
      return self.basic_forward(inputs)
    elif self.search_mode == 'search':
      return self.search_forward(inputs)
    else:
      raise ValueError('invalid search_mode = {:}'.format(self.search_mode))

  def search_forward(self, tuple_inputs):
    assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
    inputs, expected_inC, probability, index, prob = tuple_inputs
    index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
    probability = torch.squeeze(probability)
    assert len(index) == 2, 'invalid length : {:}'.format(index)
    # compute expected flop
    #coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
    expected_outC = (self.choices_tensor * probability).sum()
    expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
    if self.avg : out = self.avg( inputs )
    else        : out = inputs
    # convolutional layer
    out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
    out_bns   = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
    # merge
    out_channel = max([x.size(1) for x in out_bns])
    outA = ChannelWiseInter(out_bns[0], out_channel)
    outB = ChannelWiseInter(out_bns[1], out_channel)
    out  = outA * prob[0] + outB * prob[1]
    #out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])

    if self.relu: out = self.relu( out )
    else        : out = out
    return out, expected_outC, expected_flop

  def basic_forward(self, inputs):
    if self.avg : out = self.avg( inputs )
    else        : out = inputs
    conv = self.conv( out )
    if self.has_bn:out= self.BNs[-1]( conv )
    else          : out = conv
    if self.relu: out = self.relu( out )
    else        : out = out
    if self.InShape is None:
      self.InShape = (inputs.size(-2), inputs.size(-1))
      self.OutShape = (out.size(-2) , out.size(-1))
    return out


class SimBlock(nn.Module):
  expansion = 1
  num_conv = 1
  def __init__(self, inplanes, planes, stride):
    super(SimBlock, self).__init__()
    assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
    self.conv = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
    if stride == 2:
      self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
    elif inplanes != planes:
      self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
    else:
      self.downsample = None
    self.out_dim = planes
    self.search_mode = 'basic'

  def get_range(self):
    return self.conv.get_range()

  def get_flops(self, channels):
    assert len(channels) == 2, 'invalid channels : {:}'.format(channels)
    flop_A = self.conv.get_flops([channels[0], channels[1]])
    if hasattr(self.downsample, 'get_flops'):
      flop_C = self.downsample.get_flops([channels[0], channels[-1]])
    else:
      flop_C = 0
    if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train
      flop_C = channels[0] * channels[-1] * self.conv.OutShape[0] * self.conv.OutShape[1]
    return flop_A + flop_C

  def forward(self, inputs):
    if self.search_mode == 'basic' : return self.basic_forward(inputs)
    elif self.search_mode == 'search': return self.search_forward(inputs)
    else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode))

  def search_forward(self, tuple_inputs):
    assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
    inputs, expected_inC, probability, indexes, probs = tuple_inputs
    assert indexes.size(0) == 1 and probs.size(0) == 1 and probability.size(0) == 1, 'invalid size : {:}, {:}, {:}'.format(indexes.size(), probs.size(), probability.size())
    out, expected_next_inC, expected_flop = self.conv( (inputs, expected_inC , probability[0], indexes[0], probs[0]) )
    if self.downsample is not None:
      residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[-1], indexes[-1], probs[-1]) )
    else:
      residual, expected_flop_c = inputs, 0
    out = additive_func(residual, out)
    return out, expected_next_inC, sum([expected_flop, expected_flop_c])

  def basic_forward(self, inputs):
    basicblock = self.conv(inputs)
    if self.downsample is not None: residual = self.downsample(inputs)
    else                          : residual = inputs
    out = additive_func(residual, basicblock)
    return nn.functional.relu(out, inplace=True)


class SearchWidthSimResNet(nn.Module):

  def __init__(self, depth, num_classes):
    super(SearchWidthSimResNet, self).__init__()

    assert (depth - 2) % 3 == 0, 'depth should be one of 5, 8, 11, 14, ... instead of {:}'.format(depth)
    layer_blocks = (depth - 2) // 3
    self.message = 'SearchWidthSimResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
    self.num_classes = num_classes
    self.channels = [16]
    self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
    self.InShape = None
    for stage in range(3):
      for iL in range(layer_blocks):
        iC = self.channels[-1]
        planes = 16 * (2**stage)
        stride = 2 if stage > 0 and iL == 0 else 1
        module = SimBlock(iC, planes, stride)
        self.channels.append( module.out_dim )
        self.layers.append  ( module )
        self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride)

    self.avgpool = nn.AvgPool2d(8)
    self.classifier = nn.Linear(module.out_dim, num_classes)
    self.InShape = None
    self.tau = -1
    self.search_mode = 'basic'
    #assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)

    # parameters for width
    self.Ranges = []
    self.layer2indexRange = []
    for i, layer in enumerate(self.layers):
      start_index = len(self.Ranges)
      self.Ranges += layer.get_range()
      self.layer2indexRange.append( (start_index, len(self.Ranges)) )
    assert len(self.Ranges) + 1 == depth, 'invalid depth check {:} vs {:}'.format(len(self.Ranges) + 1, depth)

    self.register_parameter('width_attentions', nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))))
    nn.init.normal_(self.width_attentions, 0, 0.01)
    self.apply(initialize_resnet)

  def arch_parameters(self):
    return [self.width_attentions]

  def base_parameters(self):
    return list(self.layers.parameters()) + list(self.avgpool.parameters()) + list(self.classifier.parameters())

  def get_flop(self, mode, config_dict, extra_info):
    if config_dict is not None: config_dict = config_dict.copy()
    #weights = [F.softmax(x, dim=0) for x in self.width_attentions]
    channels = [3]
    for i, weight in enumerate(self.width_attentions):
      if mode == 'genotype':
        with torch.no_grad():
          probe = nn.functional.softmax(weight, dim=0)
          C = self.Ranges[i][ torch.argmax(probe).item() ]
      elif mode == 'max':
        C = self.Ranges[i][-1]
      elif mode == 'fix':
        C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] )
      elif mode == 'random':
        assert isinstance(extra_info, float), 'invalid extra_info : {:}'.format(extra_info)
        with torch.no_grad():
          prob = nn.functional.softmax(weight, dim=0)
          approximate_C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] )
          for j in range(prob.size(0)):
            prob[j] = 1 / (abs(j - (approximate_C-self.Ranges[i][j])) + 0.2)
          C = self.Ranges[i][ torch.multinomial(prob, 1, False).item() ]
      else:
        raise ValueError('invalid mode : {:}'.format(mode))
      channels.append( C )
    flop = 0
    for i, layer in enumerate(self.layers):
      s, e = self.layer2indexRange[i]
      xchl = tuple( channels[s:e+1] )
      flop+= layer.get_flops(xchl)
    # the last fc layer
    flop += channels[-1] * self.classifier.out_features
    if config_dict is None:
      return flop / 1e6
    else:
      config_dict['xchannels']      = channels
      config_dict['super_type']     = 'infer-width'
      config_dict['estimated_FLOP'] = flop / 1e6
      return flop / 1e6, config_dict

  def get_arch_info(self):
    string = "for width, there are {:} attention probabilities.".format(len(self.width_attentions))
    discrepancy = []
    with torch.no_grad():
      for i, att in enumerate(self.width_attentions):
        prob = nn.functional.softmax(att, dim=0)
        prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist()
        prob = ['{:.3f}'.format(x) for x in prob]
        xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.width_attentions), ' '.join(prob))
        logt = ['{:.3f}'.format(x) for x in att.cpu().tolist()]
        xstring += '  ||  {:52s}'.format(' '.join(logt))
        prob = sorted( [float(x) for x in prob] )
        disc = prob[-1] - prob[-2]
        xstring += '  || dis={:.2f} || select={:}/{:}'.format(disc, selc, len(prob))
        discrepancy.append( disc )
        string += '\n{:}'.format(xstring)
    return string, discrepancy

  def set_tau(self, tau_max, tau_min, epoch_ratio):
    assert epoch_ratio >= 0 and epoch_ratio <= 1, 'invalid epoch-ratio : {:}'.format(epoch_ratio)
    tau = tau_min + (tau_max-tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
    self.tau = tau

  def get_message(self):
    return self.message

  def forward(self, inputs):
    if self.search_mode == 'basic':
      return self.basic_forward(inputs)
    elif self.search_mode == 'search':
      return self.search_forward(inputs)
    else:
      raise ValueError('invalid search_mode = {:}'.format(self.search_mode))

  def search_forward(self, inputs):
    flop_probs = nn.functional.softmax(self.width_attentions, dim=1)
    selected_widths, selected_probs = select2withP(self.width_attentions, self.tau)
    with torch.no_grad():
      selected_widths = selected_widths.cpu()

    x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
    for i, layer in enumerate(self.layers):
      selected_w_index = selected_widths[last_channel_idx: last_channel_idx+layer.num_conv]
      selected_w_probs = selected_probs[last_channel_idx: last_channel_idx+layer.num_conv]
      layer_prob       = flop_probs[last_channel_idx: last_channel_idx+layer.num_conv]
      x, expected_inC, expected_flop = layer( (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) )
      last_channel_idx += layer.num_conv
      flops.append( expected_flop )
    flops.append( expected_inC * (self.classifier.out_features*1.0/1e6) )
    features = self.avgpool(x)
    features = features.view(features.size(0), -1)
    logits = linear_forward(features, self.classifier)
    return logits, torch.stack( [sum(flops)] )

  def basic_forward(self, inputs):
    if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1))
    x = inputs
    for i, layer in enumerate(self.layers):
      x = layer( x )
    features = self.avgpool(x)
    features = features.view(features.size(0), -1)
    logits = self.classifier(features)
    return features, logits
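Editor's note: an illustrative sketch (not part of the commit) that exercises the new model's basic path; the import path assumes lib/ is on sys.path, and the values mirror the CIFAR-SIM05 config (depth=5, CIFAR-sized inputs):

# Illustrative sketch: instantiate SearchWidthSimResNet and query its FLOP estimate.
import torch
from models.shape_searchs import SearchWidthSimResNet

net = SearchWidthSimResNet(depth=5, num_classes=10)
x = torch.rand(2, 3, 32, 32)                      # CIFAR-sized batch
features, logits = net(x)                         # default search_mode == 'basic'
print(net.get_flop('max', None, None))            # MFLOPs at full width (needs one forward first to set shapes)
net.set_tau(tau_max=10, tau_min=0.1, epoch_ratio=0.0)   # cosine-annealed tau for the search path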
@@ -4,4 +4,5 @@
 from .SearchCifarResNet_width    import SearchWidthCifarResNet
 from .SearchCifarResNet_depth    import SearchDepthCifarResNet
 from .SearchCifarResNet          import SearchShapeCifarResNet
+from .SearchSimResNet_width      import SearchWidthSimResNet
 from .SearchImagenetResNet       import SearchShapeImagenetResNet
lib/tf_models/__init__.py (new file, +32)
@@ -0,0 +1,32 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from os import path as osp

__all__ = ['get_cell_based_tiny_net', 'get_search_spaces']


# the cell-based NAS models
def get_cell_based_tiny_net(config):
  group_names = ['GDAS']
  if config.name in group_names:
    from .cell_searchs import nas_super_nets
    from .cell_operations import SearchSpaceNames
    if isinstance(config.space, str): search_space = SearchSpaceNames[config.space]
    else: search_space = config.space
    return nas_super_nets[config.name](
                config.C, config.N, config.max_nodes,
                config.num_classes, search_space, config.affine)
  else:
    raise ValueError('invalid network name : {:}'.format(config.name))


# obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
def get_search_spaces(xtype, name):
  if xtype == 'cell':
    from .cell_operations import SearchSpaceNames
    assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())
    return SearchSpaceNames[name]
  else:
    raise ValueError('invalid search-space type is {:}'.format(xtype))
lib/tf_models/cell_operations.py (new file, +120)
@@ -0,0 +1,120 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import tensorflow as tf

__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']

OPS = {
  'none'         : lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride),
  'avg_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg', affine),
  'nor_conv_1x1' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 1, stride, affine),
  'nor_conv_3x3' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 3, stride, affine),
  'nor_conv_5x5' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 5, stride, affine),
  'skip_connect' : lambda C_in, C_out, stride, affine: Identity(C_in, C_out, stride)
}

NAS_BENCH_102 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']

SearchSpaceNames = {
  'nas-bench-102': NAS_BENCH_102,
}


class POOLING(tf.keras.layers.Layer):

  def __init__(self, C_in, C_out, stride, mode, affine):
    super(POOLING, self).__init__()
    if C_in == C_out:
      self.preprocess = None
    else:
      self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, affine)
    if mode == 'avg'  : self.op = tf.keras.layers.AvgPool2D((3,3), strides=stride, padding='same')
    elif mode == 'max': self.op = tf.keras.layers.MaxPool2D((3,3), strides=stride, padding='same')
    else              : raise ValueError('Invalid mode={:} in POOLING'.format(mode))

  def call(self, inputs, training):
    if self.preprocess: x = self.preprocess(inputs)
    else              : x = inputs
    return self.op(x)


class Identity(tf.keras.layers.Layer):
  def __init__(self, C_in, C_out, stride):
    super(Identity, self).__init__()
    if C_in != C_out or stride != 1:
      self.layer = tf.keras.layers.Conv2D(C_out, 3, stride, padding='same', use_bias=False)
    else:
      self.layer = None

  def call(self, inputs, training):
    x = inputs
    if self.layer is not None:
      x = self.layer(x)
    return x


class Zero(tf.keras.layers.Layer):
  def __init__(self, C_in, C_out, stride):
    super(Zero, self).__init__()
    if C_in != C_out:
      self.layer = tf.keras.layers.Conv2D(C_out, 1, stride, padding='same', use_bias=False)
    elif stride != 1:
      self.layer = tf.keras.layers.AvgPool2D((stride,stride), None, padding="same")
    else:
      self.layer = None

  def call(self, inputs, training):
    x = tf.zeros_like(inputs)
    if self.layer is not None:
      x = self.layer(x)
    return x


class ReLUConvBN(tf.keras.layers.Layer):
  def __init__(self, C_in, C_out, kernel_size, strides, affine):
    super(ReLUConvBN, self).__init__()
    self.C_in = C_in
    self.relu = tf.keras.activations.relu
    self.conv = tf.keras.layers.Conv2D(C_out, kernel_size, strides, padding='same', use_bias=False)
    self.bn   = tf.keras.layers.BatchNormalization(center=affine, scale=affine)

  def call(self, inputs, training):
    x = self.relu(inputs)
    x = self.conv(x)
    x = self.bn(x, training)
    return x


class ResNetBasicblock(tf.keras.layers.Layer):

  def __init__(self, inplanes, planes, stride, affine=True):
    super(ResNetBasicblock, self).__init__()
    assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
    self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, affine)
    self.conv_b = ReLUConvBN(  planes, planes, 3,      1, affine)
    if stride == 2:
      self.downsample = tf.keras.Sequential([
                           tf.keras.layers.AvgPool2D((stride,stride), None, padding="same"),
                           tf.keras.layers.Conv2D(planes, 1, 1, padding='same', use_bias=False)])
    elif inplanes != planes:
      self.downsample = ReLUConvBN(inplanes, planes, 1, stride, affine)
    else:
      self.downsample = None
    self.addition = tf.keras.layers.Add()
    self.in_dim   = inplanes
    self.out_dim  = planes
    self.stride   = stride
    self.num_conv = 2

  def call(self, inputs, training):

    basicblock = self.conv_a(inputs, training)
    basicblock = self.conv_b(basicblock, training)

    if self.downsample is not None:
      residual = self.downsample(inputs)
    else:
      residual = inputs
    return self.addition([residual, basicblock])
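Editor's note: an illustrative sketch (not part of the commit) of building one candidate operation from the OPS table above and applying it to a dummy NHWC batch; shapes and argument values are assumptions:

# Illustrative sketch: instantiate a NAS-Bench-102 op and run it eagerly.
import tensorflow as tf
from tf_models.cell_operations import OPS

op = OPS['nor_conv_3x3'](16, 16, 1, True)   # C_in=16, C_out=16, stride=1, affine=True
x  = tf.random.normal((4, 32, 32, 16))
y  = op(x, False)                           # training=False
print(y.shape)                              # (4, 32, 32, 16)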
lib/tf_models/cell_searchs/__init__.py (new file, +6)
@@ -0,0 +1,6 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .search_model_gdas import TinyNetworkGDAS

nas_super_nets = {'GDAS': TinyNetworkGDAS}
lib/tf_models/cell_searchs/search_cells.py (new file, +50)
@@ -0,0 +1,50 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, random
import tensorflow as tf
from copy import deepcopy
from ..cell_operations import OPS


class SearchCell(tf.keras.layers.Layer):

  def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False):
    super(SearchCell, self).__init__()

    self.op_names  = deepcopy(op_names)
    self.max_nodes = max_nodes
    self.in_dim    = C_in
    self.out_dim   = C_out
    self.edge_keys = []
    for i in range(1, max_nodes):
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        if j == 0:
          xlists = [OPS[op_name](C_in , C_out, stride, affine) for op_name in op_names]
        else:
          xlists = [OPS[op_name](C_in , C_out,      1, affine) for op_name in op_names]
        for k, op in enumerate(xlists):
          setattr(self, '{:}.{:}'.format(node_str, k), op)
        self.edge_keys.append( node_str )
    self.edge_keys  = sorted(self.edge_keys)
    self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
    self.num_edges  = len(self.edge_keys)

  def call(self, inputs, weightss, training):
    w_lst = tf.split(weightss, self.num_edges, 0)
    nodes = [inputs]
    for i in range(1, self.max_nodes):
      inter_nodes = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        edge_idx = self.edge2index[node_str]
        op_outps = []
        for k, op_name in enumerate(self.op_names):
          op = getattr(self, '{:}.{:}'.format(node_str, k))
          op_outps.append( op(nodes[j], training) )
        stack_op_outs = tf.stack(op_outps, axis=-1)
        weighted_sums = tf.math.multiply(stack_op_outs, w_lst[edge_idx])
        inter_nodes.append( tf.math.reduce_sum(weighted_sums, axis=-1) )
      nodes.append( tf.math.add_n(inter_nodes) )
    return nodes[-1]
lib/tf_models/cell_searchs/search_model_gdas.py (new file, +99)
@@ -0,0 +1,99 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import tensorflow as tf
import numpy as np
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import SearchCell


def sample_gumbel(shape, eps=1e-20):
  U = tf.random.uniform(shape, minval=0, maxval=1)
  return -tf.math.log(-tf.math.log(U + eps) + eps)


def gumbel_softmax(logits, temperature):
  gumbel_softmax_sample = logits + sample_gumbel(tf.shape(logits))
  y = tf.nn.softmax(gumbel_softmax_sample / temperature)
  return y


class TinyNetworkGDAS(tf.keras.Model):

  def __init__(self, C, N, max_nodes, num_classes, search_space, affine):
    super(TinyNetworkGDAS, self).__init__()
    self._C        = C
    self._layerN   = N
    self.max_nodes = max_nodes
    self.stem = tf.keras.Sequential([
                   tf.keras.layers.Conv2D(16, 3, 1, padding='same', use_bias=False),
                   tf.keras.layers.BatchNormalization()], name='stem')

    layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N
    layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

    C_prev, num_edge, edge2index = C, None, None
    for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
      cell_prefix = 'cell-{:03d}'.format(index)
      #with tf.name_scope(cell_prefix) as scope:
      if reduction:
        cell = ResNetBasicblock(C_prev, C_curr, 2)
      else:
        cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine)
        if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
        else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
      C_prev = cell.out_dim
      setattr(self, cell_prefix, cell)
    self.num_layers = len(layer_reductions)
    self.op_names   = deepcopy( search_space )
    self.edge2index = edge2index
    self.num_edge   = num_edge
    self.lastact    = tf.keras.Sequential([
                         tf.keras.layers.BatchNormalization(),
                         tf.keras.layers.ReLU(),
                         tf.keras.layers.GlobalAvgPool2D(),
                         tf.keras.layers.Flatten(),
                         tf.keras.layers.Dense(num_classes, activation='softmax')], name='lastact')
    #self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
    arch_init = tf.random_normal_initializer(mean=0, stddev=0.001)
    self.arch_parameters = tf.Variable(initial_value=arch_init(shape=(num_edge, len(search_space)), dtype='float32'), trainable=True, name='arch-encoding')

  def get_alphas(self):
    xlist = self.trainable_variables
    return [x for x in xlist if 'arch-encoding' in x.name]

  def get_weights(self):
    xlist = self.trainable_variables
    return [x for x in xlist if 'arch-encoding' not in x.name]

  def get_np_alphas(self):
    arch_nps = self.arch_parameters.numpy()
    arch_ops = np.exp(arch_nps) / np.sum(np.exp(arch_nps), axis=-1, keepdims=True)
    return arch_ops

  def genotype(self):
    genotypes, arch_nps = [], self.arch_parameters.numpy()
    for i in range(1, self.max_nodes):
      xlist = []
      for j in range(i):
        node_str = '{:}<-{:}'.format(i, j)
        weights  = arch_nps[ self.edge2index[node_str] ]
        op_name  = self.op_names[ weights.argmax().item() ]
        xlist.append((op_name, j))
      genotypes.append( tuple(xlist) )
    return genotypes

  #
  def call(self, inputs, tau, training):
    weightss = tf.cond(tau < 0, lambda: tf.nn.softmax(self.arch_parameters, axis=1),
                                lambda: gumbel_softmax(tf.math.log_softmax(self.arch_parameters, axis=1), tau))
    feature = self.stem(inputs, training)
    for idx in range(self.num_layers):
      cell = getattr(self, 'cell-{:03d}'.format(idx))
      if isinstance(cell, SearchCell):
        feature = cell.call(feature, weightss, training)
      else:
        feature = cell(feature, training)
    logits = self.lastact(feature, training)
    return logits
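Editor's note: an illustrative sketch (not part of the commit) of one eager forward pass through the GDAS super-net defined above; the config values mirror the defaults in exps-tf/GDAS.py, and lib/ is assumed to be on sys.path:

# Illustrative sketch: build the super-net and run one forward pass plus a genotype read-out.
import tensorflow as tf
from tf_models import get_cell_based_tiny_net
from config_utils import dict2config

config = dict2config({'name': 'GDAS', 'C': 16, 'N': 5, 'max_nodes': 4,
                      'num_classes': 10, 'space': 'nas-bench-102', 'affine': True}, None)
model  = get_cell_based_tiny_net(config)

images = tf.random.normal((2, 32, 32, 3))
tau    = tf.constant(10.0)              # a negative tau would fall back to a plain softmax over alphas
logits = model(images, tau, True)       # shape (2, 10); softmax outputs from the final Dense layer
print(model.genotype())                 # per-edge argmax op, as tuples of (op_name, source_node)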
lib/tf_optimizers/__init__.py (new file, +1)
@@ -0,0 +1 @@
from .weight_decay_optimizers import AdamW, SGDW
lib/tf_optimizers/weight_decay_optimizers.py (new file, +422; listing truncated below)
@@ -0,0 +1,422 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Base class to make optimizers weight decay ready."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
class DecoupledWeightDecayExtension(object):
|
||||||
|
"""This class allows to extend optimizers with decoupled weight decay.
|
||||||
|
|
||||||
|
It implements the decoupled weight decay described by Loshchilov & Hutter
|
||||||
|
(https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is
|
||||||
|
decoupled from the optimization steps w.r.t. to the loss function.
|
||||||
|
For SGD variants, this simplifies hyperparameter search since it decouples
|
||||||
|
the settings of weight decay and learning rate.
|
||||||
|
For adaptive gradient algorithms, it regularizes variables with large
|
||||||
|
gradients more than L2 regularization would, which was shown to yield
|
||||||
|
better training loss and generalization error in the paper above.
|
||||||
|
|
||||||
|
This class alone is not an optimizer but rather extends existing
|
||||||
|
optimizers with decoupled weight decay. We explicitly define the two
|
||||||
|
examples used in the above paper (SGDW and AdamW), but in general this
|
||||||
|
can extend any OptimizerX by using
|
||||||
|
`extend_with_decoupled_weight_decay(
|
||||||
|
OptimizerX, weight_decay=weight_decay)`.
|
||||||
|
In order for it to work, it must be the first class the Optimizer with
|
||||||
|
weight decay inherits from, e.g.
|
||||||
|
|
||||||
|
```python
|
||||||
|
class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
|
||||||
|
def __init__(self, weight_decay, *args, **kwargs):
|
||||||
|
super(AdamW, self).__init__(weight_decay, *args, **kwargs).
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: this extension decays weights BEFORE applying the update based
|
||||||
|
on the gradient, i.e. this extension only has the desired behaviour for
|
||||||
|
optimizers which do not depend on the value of'var' in the update step!
|
||||||
|
|
||||||
|
Note: when applying a decay to the learning rate, be sure to manually apply
|
||||||
|
the decay to the `weight_decay` as well. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
step = tf.Variable(0, trainable=False)
|
||||||
|
schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
|
||||||
|
[10000, 15000], [1e-0, 1e-1, 1e-2])
|
||||||
|
# lr and wd can be a function or a tensor
|
||||||
|
lr = 1e-1 * schedule(step)
|
||||||
|
wd = lambda: 1e-4 * schedule(step)
|
||||||
|
|
||||||
|
# ...
|
||||||
|
|
||||||
|
optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, weight_decay, **kwargs):
|
||||||
|
"""Extension class that adds weight decay to an optimizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weight_decay: A `Tensor` or a floating point value, the factor by
|
||||||
|
which a variable is decayed in the update step.
|
||||||
|
**kwargs: Optional list or tuple or set of `Variable` objects to
|
||||||
|
decay.
|
||||||
|
"""
|
||||||
|
wd = kwargs.pop('weight_decay', weight_decay)
|
||||||
|
super(DecoupledWeightDecayExtension, self).__init__(**kwargs)
|
||||||
|
self._decay_var_list = None # is set in minimize or apply_gradients
|
||||||
|
self._set_hyper('weight_decay', wd)
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
config = super(DecoupledWeightDecayExtension, self).get_config()
|
||||||
|
config.update({
|
||||||
|
'weight_decay':
|
||||||
|
self._serialize_hyperparameter('weight_decay'),
|
||||||
|
})
|
||||||
|
return config
|
||||||
|
|
||||||
|
def minimize(self,
|
||||||
|
loss,
|
||||||
|
var_list,
|
||||||
|
grad_loss=None,
|
||||||
|
name=None,
|
||||||
|
decay_var_list=None):
|
||||||
|
"""Minimize `loss` by updating `var_list`.
|
||||||
|
|
||||||
|
This method simply computes gradient using `tf.GradientTape` and calls
|
||||||
|
`apply_gradients()`. If you want to process the gradient before
|
||||||
|
applying then call `tf.GradientTape` and `apply_gradients()` explicitly
|
||||||
|
instead of using this function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss: A callable taking no arguments which returns the value to
|
||||||
|
minimize.
|
||||||
|
var_list: list or tuple of `Variable` objects to update to
|
||||||
|
minimize `loss`, or a callable returning the list or tuple of
|
||||||
|
`Variable` objects. Use callable when the variable list would
|
||||||
|
otherwise be incomplete before `minimize` since the variables
|
||||||
|
are created at the first time `loss` is called.
|
||||||
|
grad_loss: Optional. A `Tensor` holding the gradient computed for
|
||||||
|
`loss`.
|
||||||
|
decay_var_list: Optional list of variables to be decayed. Defaults
|
||||||
|
to all variables in var_list.
|
||||||
|
name: Optional name for the returned operation.
|
||||||
|
Returns:
|
||||||
|
An Operation that updates the variables in `var_list`. If
|
||||||
|
`global_step` was not `None`, that operation also increments
|
||||||
|
`global_step`.
|
||||||
|
Raises:
|
||||||
|
ValueError: If some of the variables are not `Variable` objects.
|
||||||
|
"""
|
||||||
|
    self._decay_var_list = set(decay_var_list) if decay_var_list else False
    return super(DecoupledWeightDecayExtension, self).minimize(
        loss, var_list=var_list, grad_loss=grad_loss, name=name)

  def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      name: Optional name for the returned operation. Default to the
        name passed to the `Optimizer` constructor.
      decay_var_list: Optional list of variables to be decayed. Defaults
        to all variables in var_list.

    Returns:
      An `Operation` that applies the specified gradients.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    self._decay_var_list = set(decay_var_list) if decay_var_list else False
    return super(DecoupledWeightDecayExtension, self).apply_gradients(
        grads_and_vars, name=name)

  def _decay_weights_op(self, var):
    if not self._decay_var_list or var in self._decay_var_list:
      return var.assign_sub(
          self._get_hyper('weight_decay', var.dtype) * var,
          self._use_locking)
    return tf.no_op()

  def _decay_weights_sparse_op(self, var, indices):
    if not self._decay_var_list or var in self._decay_var_list:
      update = (-self._get_hyper('weight_decay', var.dtype) * tf.gather(
          var, indices))
      return self._resource_scatter_add(var, indices, update)
    return tf.no_op()

  # Here, we overwrite the apply functions that the base optimizer calls.
  # super().apply_x resolves to the apply_x function of the BaseOptimizer.

  def _resource_apply_dense(self, grad, var):
    with tf.control_dependencies([self._decay_weights_op(var)]):
      return super(DecoupledWeightDecayExtension,
                   self)._resource_apply_dense(grad, var)

  def _resource_apply_sparse(self, grad, var, indices):
    decay_op = self._decay_weights_sparse_op(var, indices)
    with tf.control_dependencies([decay_op]):
      return super(DecoupledWeightDecayExtension,
                   self)._resource_apply_sparse(grad, var, indices)


def extend_with_decoupled_weight_decay(base_optimizer):
  """Factory function returning an optimizer class with decoupled weight
  decay.

  Returns an optimizer class. An instance of the returned class computes the
  update step of `base_optimizer` and additionally decays the weights.
  E.g., the class returned by
  `extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)` is
  equivalent to `tfa.optimizers.AdamW`.

  The API of the new optimizer class slightly differs from the API of the
  base optimizer:
  - The first argument to the constructor is the weight decay rate.
  - `minimize` and `apply_gradients` accept the optional keyword argument
    `decay_var_list`, which specifies the variables that should be decayed.
    If `None`, all variables that are optimized are decayed.

  Usage example:
  ```python
  # MyAdamW is a new class
  MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)
  # Create a MyAdamW object
  optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
  # update var1, var2 but only decay var1
  optimizer.minimize(loss, var_list=[var1, var2], decay_var_list=[var1])
  ```

  Note: this extension decays weights BEFORE applying the update based
  on the gradient, i.e. this extension only has the desired behaviour for
  optimizers which do not depend on the value of 'var' in the update step!

  Note: when applying a decay to the learning rate, be sure to manually apply
  the decay to the `weight_decay` as well. For example:

  ```python
  step = tf.Variable(0, trainable=False)
  schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
      [10000, 15000], [1e-0, 1e-1, 1e-2])
  # lr and wd can be a function or a tensor
  lr = 1e-1 * schedule(step)
  wd = lambda: 1e-4 * schedule(step)

  # ...

  optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
  ```

  Note: you might want to register your own custom optimizer using
  `tf.keras.utils.get_custom_objects()`.

  Args:
    base_optimizer: An optimizer class that inherits from
      tf.optimizers.Optimizer.

  Returns:
    A new optimizer class that inherits from DecoupledWeightDecayExtension
    and base_optimizer.
  """

  class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
                                          base_optimizer):
    """Base_optimizer with decoupled weight decay.

    This class computes the update step of `base_optimizer` and
    additionally decays the variable with the weight decay being
    decoupled from the optimization steps w.r.t. to the loss
    function, as described by Loshchilov & Hutter
    (https://arxiv.org/pdf/1711.05101.pdf). For SGD variants, this
    simplifies hyperparameter search since it decouples the settings
    of weight decay and learning rate. For adaptive gradient
    algorithms, it regularizes variables with large gradients more
    than L2 regularization would, which was shown to yield better
    training loss and generalization error in the paper above.
    """

    def __init__(self, weight_decay, *args, **kwargs):
      # super delegation is necessary here
      super(OptimizerWithDecoupledWeightDecay, self).__init__(
          weight_decay, *args, **kwargs)

  return OptimizerWithDecoupledWeightDecay
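
# A minimal usage sketch for the factory above (illustrative only; the name
# `RMSpropW` is hypothetical and the hyperparameter values are arbitrary).
# Because `OptimizerWithDecoupledWeightDecay` lists the extension before
# `base_optimizer`, the overridden `_resource_apply_*` methods run the decay op
# first and then delegate to the base optimizer's update:
#
#   RMSpropW = extend_with_decoupled_weight_decay(tf.keras.optimizers.RMSprop)
#   optimizer = RMSpropW(weight_decay=1e-4, learning_rate=1e-3, rho=0.9)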

class SGDW(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD):
  """Optimizer that implements the Momentum algorithm with weight_decay.

  This is an implementation of the SGDW optimizer described in "Decoupled
  Weight Decay Regularization" by Loshchilov & Hutter
  (https://arxiv.org/abs/1711.05101)
  ([pdf](https://arxiv.org/pdf/1711.05101.pdf)).
  It computes the update step of `tf.keras.optimizers.SGD` and additionally
  decays the variable. Note that this is different from adding
  L2 regularization on the variables to the loss. Decoupling the weight decay
  from other hyperparameters (in particular the learning rate) simplifies
  hyperparameter search.

  For further information see the documentation of the SGD Optimizer.

  This optimizer can also be instantiated as
  ```python
  extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD,
                                     weight_decay=weight_decay)
  ```

  Note: when applying a decay to the learning rate, be sure to manually apply
  the decay to the `weight_decay` as well. For example:

  ```python
  step = tf.Variable(0, trainable=False)
  schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
      [10000, 15000], [1e-0, 1e-1, 1e-2])
  # lr and wd can be a function or a tensor
  lr = 1e-1 * schedule(step)
  wd = lambda: 1e-4 * schedule(step)

  # ...

  optimizer = tfa.optimizers.SGDW(
      learning_rate=lr, weight_decay=wd, momentum=0.9)
  ```
  """

  def __init__(self,
               weight_decay,
               learning_rate=0.001,
               momentum=0.0,
               nesterov=False,
               name='SGDW',
               **kwargs):
    """Construct a new SGDW optimizer.

    For further information see the documentation of the SGD Optimizer.

    Args:
      weight_decay: A `Tensor` or a floating point value. The weight decay.
      learning_rate: float hyperparameter >= 0. Learning rate.
      momentum: float hyperparameter >= 0 that accelerates SGD in the
        relevant direction and dampens oscillations.
      nesterov: boolean. Whether to apply Nesterov momentum.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to 'SGDW'.
      **kwargs: keyword arguments. Allowed to be {`clipnorm`,
        `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
        norm; `clipvalue` is clip gradients by value, `decay` is
        included for backward compatibility to allow time inverse decay
        of learning rate. `lr` is included for backward compatibility,
        recommended to use `learning_rate` instead.
    """
    super(SGDW, self).__init__(
        weight_decay,
        learning_rate=learning_rate,
        momentum=momentum,
        nesterov=nesterov,
        name=name,
        **kwargs)
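
# In this implementation the decay step is `var.assign_sub(weight_decay * var)`,
# applied just before the SGD update, i.e. roughly `w <- (1 - weight_decay) * w`
# followed by the usual `w <- w - lr * update`. With plain L2 regularization the
# `weight_decay * w` term instead enters the gradient, so it is scaled by the
# learning rate and folded into the momentum buffer, which is exactly the coupling
# that SGDW removes. A hedged sketch (hyperparameter values are illustrative only):
#
#   opt = SGDW(weight_decay=5e-4, learning_rate=0.1, momentum=0.9, nesterov=True)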

class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
  """Optimizer that implements the Adam algorithm with weight decay.

  This is an implementation of the AdamW optimizer described in "Decoupled
  Weight Decay Regularization" by Loshchilov & Hutter
  (https://arxiv.org/abs/1711.05101)
  ([pdf](https://arxiv.org/pdf/1711.05101.pdf)).

  It computes the update step of `tf.keras.optimizers.Adam` and additionally
  decays the variable. Note that this is different from adding L2
  regularization on the variables to the loss: it regularizes variables with
  large gradients more than L2 regularization would, which was shown to yield
  better training loss and generalization error in the paper above.

  For further information see the documentation of the Adam Optimizer.

  This optimizer can also be instantiated as
  ```python
  extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam,
                                     weight_decay=weight_decay)
  ```

  Note: when applying a decay to the learning rate, be sure to manually apply
  the decay to the `weight_decay` as well. For example:

  ```python
  step = tf.Variable(0, trainable=False)
  schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
      [10000, 15000], [1e-0, 1e-1, 1e-2])
  # lr and wd can be a function or a tensor
  lr = 1e-1 * schedule(step)
  wd = lambda: 1e-4 * schedule(step)

  # ...

  optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
  ```
  """

  def __init__(self,
               weight_decay,
               learning_rate=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-07,
               amsgrad=False,
               name="AdamW",
               **kwargs):
    """Construct a new AdamW optimizer.

    For further information see the documentation of the Adam Optimizer.

    Args:
      weight_decay: A Tensor or a floating point value. The weight decay.
      learning_rate: A Tensor or a floating point value. The learning
        rate.
      beta_1: A float value or a constant float tensor. The exponential
        decay rate for the 1st moment estimates.
      beta_2: A float value or a constant float tensor. The exponential
        decay rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just
        before Section 2.1), not the epsilon in Algorithm 1 of the
        paper.
      amsgrad: boolean. Whether to apply AMSGrad variant of this
        algorithm from the paper "On the Convergence of Adam and
        beyond".
      name: Optional name for the operations created when applying
        gradients. Defaults to "AdamW".
      **kwargs: keyword arguments. Allowed to be {`clipnorm`,
        `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
        norm; `clipvalue` is clip gradients by value, `decay` is
        included for backward compatibility to allow time inverse decay
        of learning rate. `lr` is included for backward compatibility,
        recommended to use `learning_rate` instead.
    """
    super(AdamW, self).__init__(
        weight_decay,
        learning_rate=learning_rate,
        beta_1=beta_1,
        beta_2=beta_2,
        epsilon=epsilon,
        amsgrad=amsgrad,
        name=name,
        **kwargs)
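
# A minimal end-to-end sketch of `AdamW` with `decay_var_list`, assuming a
# hypothetical Keras `model`, batch `(x, y)`, and `loss_fn`; only the optimizer
# calls themselves come from the classes above.
#
#   opt = AdamW(weight_decay=1e-4, learning_rate=1e-3)
#   with tf.GradientTape() as tape:
#     loss = loss_fn(model(x, training=True), y)
#   grads = tape.gradient(loss, model.trainable_variables)
#   # decay only the kernels, e.g. skip batch-norm and bias variables
#   decay_vars = [v for v in model.trainable_variables
#                 if 'bn' not in v.name and 'bias' not in v.name]
#   opt.apply_gradients(zip(grads, model.trainable_variables),
#                       decay_var_list=decay_vars)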