update TF models (beta version)

This commit is contained in:
D-X-Y 2020-01-05 22:19:38 +11:00
parent e6ca3628ce
commit 5ac5060a33
18 changed files with 1253 additions and 44 deletions

View File

@ -0,0 +1,7 @@
{
"dataset" : ["str", "cifar"],
"arch" : ["str", "simres"],
"depth" : ["int", 5],
"super_type": ["str" , "basic"],
"zero_init_residual" : ["bool", "0"]
}

144
exps-tf/GDAS.py Normal file
View File

@ -0,0 +1,144 @@
# CUDA_VISIBLE_DEVICES=0 python exps-tf/GDAS.py
import os, sys, time, random, argparse
import tensorflow as tf
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
# self-lib
from tf_models import get_cell_based_tiny_net
from tf_optimizers import SGDW, AdamW
from config_utils import dict2config
from log_utils import time_string
from models import CellStructure
def pre_process(image_a, label_a, image_b, label_b):
def standard_func(image):
x = tf.pad(image, [[4, 4], [4, 4], [0, 0]])
x = tf.image.random_crop(x, [32, 32, 3])
x = tf.image.random_flip_left_right(x)
return x
return standard_func(image_a), label_a, standard_func(image_b), label_b
def main(xargs):
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train, x_test = x_train.astype('float32'), x_test.astype('float32')
# Add a channels dimension
all_indexes = list(range(x_train.shape[0]))
random.shuffle(all_indexes)
s_train_idxs, s_valid_idxs = all_indexes[::2], all_indexes[1::2]
search_train_x, search_train_y = x_train[s_train_idxs], y_train[s_train_idxs]
search_valid_x, search_valid_y = x_train[s_valid_idxs], y_train[s_valid_idxs]
#x_train, x_test = x_train[..., tf.newaxis], x_test[..., tf.newaxis]
# Use tf.data
#train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(64)
search_ds = tf.data.Dataset.from_tensor_slices((search_train_x, search_train_y, search_valid_x, search_valid_y))
search_ds = search_ds.map(pre_process).shuffle(1000).batch(64)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
# Create an instance of the model
config = dict2config({'name': 'GDAS',
'C' : xargs.channel, 'N': xargs.num_cells, 'max_nodes': xargs.max_nodes,
'num_classes': 10, 'space': 'nas-bench-102', 'affine': True}, None)
model = get_cell_based_tiny_net(config)
#import pdb; pdb.set_trace()
#model.build(((64, 32, 32, 3), (1,)))
#for x in model.trainable_variables:
# print('{:30s} : {:}'.format(x.name, x.shape))
# Choose optimizer
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
w_optimizer = SGDW(learning_rate=xargs.w_lr, weight_decay=xargs.w_weight_decay, momentum=xargs.w_momentum, nesterov=True)
a_optimizer = AdamW(learning_rate=xargs.arch_learning_rate, weight_decay=xargs.arch_weight_decay, beta_1=0.5, beta_2=0.999, epsilon=1e-07)
#w_optimizer = tf.keras.optimizers.SGD(learning_rate=0.025, momentum=0.9, nesterov=True)
#a_optimizer = tf.keras.optimizers.AdamW(learning_rate=xargs.arch_learning_rate, beta_1=0.5, beta_2=0.999, epsilon=1e-07)
####
# metrics
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
valid_loss = tf.keras.metrics.Mean(name='valid_loss')
valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='valid_accuracy')
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
@tf.function
def search_step(train_images, train_labels, valid_images, valid_labels, tf_tau):
# optimize weights
with tf.GradientTape() as tape:
predictions = model(train_images, tf_tau, True)
w_loss = loss_object(train_labels, predictions)
net_w_param = model.get_weights()
gradients = tape.gradient(w_loss, net_w_param)
w_optimizer.apply_gradients(zip(gradients, net_w_param))
train_loss(w_loss)
train_accuracy(train_labels, predictions)
# optimize alphas
with tf.GradientTape() as tape:
predictions = model(valid_images, tf_tau, True)
a_loss = loss_object(valid_labels, predictions)
net_a_param = model.get_alphas()
gradients = tape.gradient(a_loss, net_a_param)
a_optimizer.apply_gradients(zip(gradients, net_a_param))
valid_loss(a_loss)
valid_accuracy(valid_labels, predictions)
# TEST
@tf.function
def test_step(images, labels):
predictions = model(images)
t_loss = loss_object(labels, predictions)
test_loss(t_loss)
test_accuracy(labels, predictions)
print('{:} start searching with {:} epochs ({:} batches per epoch).'.format(time_string(), xargs.epochs, tf.data.experimental.cardinality(search_ds).numpy()))
for epoch in range(xargs.epochs):
# Reset the metrics at the start of the next epoch
train_loss.reset_states() ; train_accuracy.reset_states()
test_loss.reset_states() ; test_accuracy.reset_states()
cur_tau = xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (xargs.epochs-1)
tf_tau = tf.cast(cur_tau, dtype=tf.float32, name='tau')
for trn_imgs, trn_labels, val_imgs, val_labels in search_ds:
search_step(trn_imgs, trn_labels, val_imgs, val_labels, tf_tau)
genotype = model.genotype()
genotype = CellStructure(genotype)
#for test_images, test_labels in test_ds:
# test_step(test_images, test_labels)
template = '{:} Epoch {:03d}/{:03d}, Train-Loss: {:.3f}, Train-Accuracy: {:.2f}%, Valid-Loss: {:.3f}, Valid-Accuracy: {:.2f}% | tau={:.3f}'
print(template.format(time_string(), epoch+1, xargs.epochs,
train_loss.result(),
train_accuracy.result()*100,
valid_loss.result(),
valid_accuracy.result()*100,
cur_tau))
print('{:} genotype : {:}\n{:}\n'.format(time_string(), genotype, model.get_np_alphas()))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# training details
parser.add_argument('--epochs' , type=int , default= 250 , help='')
parser.add_argument('--tau_max' , type=float, default= 10 , help='')
parser.add_argument('--tau_min' , type=float, default= 0.1 , help='')
parser.add_argument('--w_lr' , type=float, default= 0.025, help='')
parser.add_argument('--w_weight_decay' , type=float, default=0.0005, help='')
parser.add_argument('--w_momentum' , type=float, default= 0.9 , help='')
parser.add_argument('--arch_learning_rate', type=float, default=0.0003, help='')
parser.add_argument('--arch_weight_decay' , type=float, default=0.001, help='')
# marco structure
parser.add_argument('--channel' , type=int , default=16, help='')
parser.add_argument('--num_cells' , type=int , default= 5, help='')
parser.add_argument('--max_nodes' , type=int , default= 4, help='')
args = parser.parse_args()
main( args )

View File

@ -6,7 +6,6 @@ import numpy as np
from collections import OrderedDict from collections import OrderedDict
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from graphviz import Digraph
def test_nas_api(): def test_nas_api():
@ -29,6 +28,7 @@ OPS = ['skip-connect', 'conv-1x1', 'conv-3x3', 'pool-3x3']
COLORS = ['chartreuse' , 'cyan' , 'navyblue', 'chocolate1'] COLORS = ['chartreuse' , 'cyan' , 'navyblue', 'chocolate1']
def plot(filename): def plot(filename):
from graphviz import Digraph
g = Digraph( g = Digraph(
format='png', format='png',
edge_attr=dict(fontsize='20', fontname="times"), edge_attr=dict(fontsize='20', fontname="times"),
@ -53,6 +53,26 @@ def plot(filename):
g.render(filename, cleanup=True, view=False) g.render(filename, cleanup=True, view=False)
def test_auto_grad():
class Net(torch.nn.Module):
def __init__(self, iS):
super(Net, self).__init__()
self.layer = torch.nn.Linear(iS, 1)
def forward(self, inputs):
outputs = self.layer(inputs)
outputs = torch.exp(outputs)
return outputs.mean()
net = Net(10)
inputs = torch.rand(256, 10)
loss = net(inputs)
first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True)
first_order_grads = torch.cat([x.view(-1) for x in first_order_grads])
second_order_grads = []
for grads in first_order_grads:
s_grads = torch.autograd.grad(grads, net.parameters())
second_order_grads.append( s_grads )
if __name__ == '__main__': if __name__ == '__main__':
test_nas_api() #test_nas_api()
for i in range(200): plot('{:04d}'.format(i)) #for i in range(200): plot('{:04d}'.format(i))
test_auto_grad()

View File

@ -1,7 +1,8 @@
################################################## ##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
################################################## ##################################################
from .logger import Logger # every package does not rely on pytorch or tensorflow
from .print_logger import PrintLogger # I tried to list all dependency here: os, sys, time, numpy, (possibly) matplotlib
from .logger import Logger, PrintLogger
from .meter import AverageMeter from .meter import AverageMeter
from .time_utils import time_for_file, time_string, time_string_short, time_print, convert_size2str, convert_secs2time from .time_utils import time_for_file, time_string, time_string_short, time_print, convert_secs2time

View File

@ -1,9 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. ##################################################
# All rights reserved. # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
# ##################################################
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from pathlib import Path from pathlib import Path
import importlib, warnings import importlib, warnings
import os, sys, time, numpy as np import os, sys, time, numpy as np
@ -16,6 +13,19 @@ if importlib.util.find_spec('tensorflow'):
import tensorflow as tf import tensorflow as tf
class PrintLogger(object):
def __init__(self):
"""Create a summary writer logging to log_dir."""
self.name = 'PrintLogger'
def log(self, string):
print (string)
def close(self):
print ('-'*30 + ' close printer ' + '-'*30)
class Logger(object): class Logger(object):
def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False): def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False):

View File

@ -1,4 +1,3 @@
import time, sys
import numpy as np import numpy as np

View File

@ -1,14 +0,0 @@
import os, sys, time
class PrintLogger(object):
def __init__(self):
"""Create a summary writer logging to log_dir."""
self.name = 'PrintLogger'
def log(self, string):
print (string)
def close(self):
print ('-'*30 + ' close printer ' + '-'*30)

View File

@ -1,37 +1,27 @@
# Copyright (c) Facebook, Inc. and its affiliates. ##################################################
# All rights reserved. # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
# ##################################################
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import time, sys import time, sys
import numpy as np import numpy as np
def time_for_file(): def time_for_file():
ISOTIMEFORMAT='%d-%h-at-%H-%M-%S' ISOTIMEFORMAT='%d-%h-at-%H-%M-%S'
return '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) return '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
def time_string(): def time_string():
ISOTIMEFORMAT='%Y-%m-%d %X' ISOTIMEFORMAT='%Y-%m-%d %X'
string = '[{}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) string = '[{:}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string return string
def time_string_short(): def time_string_short():
ISOTIMEFORMAT='%Y%m%d' ISOTIMEFORMAT='%Y%m%d'
string = '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) string = '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
return string return string
def time_print(string, is_print=True): def time_print(string, is_print=True):
if (is_print): if (is_print):
print('{} : {}'.format(time_string(), string)) print('{} : {}'.format(time_string(), string))
def convert_size2str(torch_size):
dims = len(torch_size)
string = '['
for idim in range(dims):
string = string + ' {}'.format(torch_size[idim])
return string + ']'
def convert_secs2time(epoch_time, return_str=False): def convert_secs2time(epoch_time, return_str=False):
need_hour = int(epoch_time / 3600) need_hour = int(epoch_time / 3600)
need_mins = int((epoch_time - 3600*need_hour) / 60) need_mins = int((epoch_time - 3600*need_hour) / 60)

View File

@ -1,7 +1,6 @@
################################################## ##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
################################################## ##################################################
import torch
from os import path as osp from os import path as osp
__all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \ __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \
@ -126,6 +125,11 @@ def obtain_search_model(config):
elif config.search_mode == 'shape': elif config.search_mode == 'shape':
return SearchShapeCifarResNet(config.module, config.depth, config.class_num) return SearchShapeCifarResNet(config.module, config.depth, config.class_num)
else: raise ValueError('invalid search mode : {:}'.format(config.search_mode)) else: raise ValueError('invalid search mode : {:}'.format(config.search_mode))
elif config.arch == 'simres':
from .shape_searchs import SearchWidthSimResNet
if config.search_mode == 'width':
return SearchWidthSimResNet(config.depth, config.class_num)
else: raise ValueError('invalid search mode : {:}'.format(config.search_mode))
else: else:
raise ValueError('invalid arch : {:} for dataset [{:}]'.format(config.arch, config.dataset)) raise ValueError('invalid arch : {:} for dataset [{:}]'.format(config.arch, config.dataset))
elif config.dataset == 'imagenet': elif config.dataset == 'imagenet':
@ -140,6 +144,7 @@ def obtain_search_model(config):
def load_net_from_checkpoint(checkpoint): def load_net_from_checkpoint(checkpoint):
import torch
assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint) assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint)
checkpoint = torch.load(checkpoint) checkpoint = torch.load(checkpoint)
model_config = dict2config(checkpoint['model-config'], None) model_config = dict2config(checkpoint['model-config'], None)

View File

@ -0,0 +1,316 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, torch
import torch.nn as nn
from ..initialization import initialize_resnet
from ..SharedUtils import additive_func
from .SoftSelect import select2withP, ChannelWiseInter
from .SoftSelect import linear_forward
from .SoftSelect import get_width_choices as get_choices
def conv_forward(inputs, conv, choices):
iC = conv.in_channels
fill_size = list(inputs.size())
fill_size[1] = iC - fill_size[1]
filled = torch.zeros(fill_size, device=inputs.device)
xinputs = torch.cat((inputs, filled), dim=1)
outputs = conv(xinputs)
selecteds = [outputs[:,:oC] for oC in choices]
return selecteds
class ConvBNReLU(nn.Module):
num_conv = 1
def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu):
super(ConvBNReLU, self).__init__()
self.InShape = None
self.OutShape = None
self.choices = get_choices(nOut)
self.register_buffer('choices_tensor', torch.Tensor( self.choices ))
if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else : self.avg = None
self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias)
#if has_bn : self.bn = nn.BatchNorm2d(nOut)
#else : self.bn = None
self.has_bn = has_bn
self.BNs = nn.ModuleList()
for i, _out in enumerate(self.choices):
self.BNs.append(nn.BatchNorm2d(_out))
if has_relu: self.relu = nn.ReLU(inplace=True)
else : self.relu = None
self.in_dim = nIn
self.out_dim = nOut
self.search_mode = 'basic'
def get_flops(self, channels, check_range=True, divide=1):
iC, oC = channels
if check_range: assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, '{:} vs {:} | {:} vs {:}'.format(iC, self.conv.in_channels, oC, self.conv.out_channels)
assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, 'invalid in-shape : {:}'.format(self.InShape)
assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, 'invalid out-shape : {:}'.format(self.OutShape)
#conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups
conv_per_position_flops = (self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups)
all_positions = self.OutShape[0] * self.OutShape[1]
flops = (conv_per_position_flops * all_positions / divide) * iC * oC
if self.conv.bias is not None: flops += all_positions / divide
return flops
def get_range(self):
return [self.choices]
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, index, prob = tuple_inputs
index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob)
probability = torch.squeeze(probability)
assert len(index) == 2, 'invalid length : {:}'.format(index)
# compute expected flop
#coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability)
expected_outC = (self.choices_tensor * probability).sum()
expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6)
if self.avg : out = self.avg( inputs )
else : out = inputs
# convolutional layer
out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index])
out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)]
# merge
out_channel = max([x.size(1) for x in out_bns])
outA = ChannelWiseInter(out_bns[0], out_channel)
outB = ChannelWiseInter(out_bns[1], out_channel)
out = outA * prob[0] + outB * prob[1]
#out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1])
if self.relu: out = self.relu( out )
else : out = out
return out, expected_outC, expected_flop
def basic_forward(self, inputs):
if self.avg : out = self.avg( inputs )
else : out = inputs
conv = self.conv( out )
if self.has_bn:out= self.BNs[-1]( conv )
else : out = conv
if self.relu: out = self.relu( out )
else : out = out
if self.InShape is None:
self.InShape = (inputs.size(-2), inputs.size(-1))
self.OutShape = (out.size(-2) , out.size(-1))
return out
class SimBlock(nn.Module):
expansion = 1
num_conv = 1
def __init__(self, inplanes, planes, stride):
super(SimBlock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True)
if stride == 2:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False)
elif inplanes != planes:
self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False)
else:
self.downsample = None
self.out_dim = planes
self.search_mode = 'basic'
def get_range(self):
return self.conv.get_range()
def get_flops(self, channels):
assert len(channels) == 2, 'invalid channels : {:}'.format(channels)
flop_A = self.conv.get_flops([channels[0], channels[1]])
if hasattr(self.downsample, 'get_flops'):
flop_C = self.downsample.get_flops([channels[0], channels[-1]])
else:
flop_C = 0
if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train
flop_C = channels[0] * channels[-1] * self.conv.OutShape[0] * self.conv.OutShape[1]
return flop_A + flop_C
def forward(self, inputs):
if self.search_mode == 'basic' : return self.basic_forward(inputs)
elif self.search_mode == 'search': return self.search_forward(inputs)
else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, tuple_inputs):
assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) )
inputs, expected_inC, probability, indexes, probs = tuple_inputs
assert indexes.size(0) == 1 and probs.size(0) == 1 and probability.size(0) == 1, 'invalid size : {:}, {:}, {:}'.format(indexes.size(), probs.size(), probability.size())
out, expected_next_inC, expected_flop = self.conv( (inputs, expected_inC , probability[0], indexes[0], probs[0]) )
if self.downsample is not None:
residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[-1], indexes[-1], probs[-1]) )
else:
residual, expected_flop_c = inputs, 0
out = additive_func(residual, out)
return out, expected_next_inC, sum([expected_flop, expected_flop_c])
def basic_forward(self, inputs):
basicblock = self.conv(inputs)
if self.downsample is not None: residual = self.downsample(inputs)
else : residual = inputs
out = additive_func(residual, basicblock)
return nn.functional.relu(out, inplace=True)
class SearchWidthSimResNet(nn.Module):
def __init__(self, depth, num_classes):
super(SearchWidthSimResNet, self).__init__()
assert (depth - 2) % 3 == 0, 'depth should be one of 5, 8, 11, 14, ... instead of {:}'.format(depth)
layer_blocks = (depth - 2) // 3
self.message = 'SearchWidthSimResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks)
self.num_classes = num_classes
self.channels = [16]
self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] )
self.InShape = None
for stage in range(3):
for iL in range(layer_blocks):
iC = self.channels[-1]
planes = 16 * (2**stage)
stride = 2 if stage > 0 and iL == 0 else 1
module = SimBlock(iC, planes, stride)
self.channels.append( module.out_dim )
self.layers.append ( module )
self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride)
self.avgpool = nn.AvgPool2d(8)
self.classifier = nn.Linear(module.out_dim, num_classes)
self.InShape = None
self.tau = -1
self.search_mode = 'basic'
#assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth)
# parameters for width
self.Ranges = []
self.layer2indexRange = []
for i, layer in enumerate(self.layers):
start_index = len(self.Ranges)
self.Ranges += layer.get_range()
self.layer2indexRange.append( (start_index, len(self.Ranges)) )
assert len(self.Ranges) + 1 == depth, 'invalid depth check {:} vs {:}'.format(len(self.Ranges) + 1, depth)
self.register_parameter('width_attentions', nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None))))
nn.init.normal_(self.width_attentions, 0, 0.01)
self.apply(initialize_resnet)
def arch_parameters(self):
return [self.width_attentions]
def base_parameters(self):
return list(self.layers.parameters()) + list(self.avgpool.parameters()) + list(self.classifier.parameters())
def get_flop(self, mode, config_dict, extra_info):
if config_dict is not None: config_dict = config_dict.copy()
#weights = [F.softmax(x, dim=0) for x in self.width_attentions]
channels = [3]
for i, weight in enumerate(self.width_attentions):
if mode == 'genotype':
with torch.no_grad():
probe = nn.functional.softmax(weight, dim=0)
C = self.Ranges[i][ torch.argmax(probe).item() ]
elif mode == 'max':
C = self.Ranges[i][-1]
elif mode == 'fix':
C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] )
elif mode == 'random':
assert isinstance(extra_info, float), 'invalid extra_info : {:}'.format(extra_info)
with torch.no_grad():
prob = nn.functional.softmax(weight, dim=0)
approximate_C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] )
for j in range(prob.size(0)):
prob[j] = 1 / (abs(j - (approximate_C-self.Ranges[i][j])) + 0.2)
C = self.Ranges[i][ torch.multinomial(prob, 1, False).item() ]
else:
raise ValueError('invalid mode : {:}'.format(mode))
channels.append( C )
flop = 0
for i, layer in enumerate(self.layers):
s, e = self.layer2indexRange[i]
xchl = tuple( channels[s:e+1] )
flop+= layer.get_flops(xchl)
# the last fc layer
flop += channels[-1] * self.classifier.out_features
if config_dict is None:
return flop / 1e6
else:
config_dict['xchannels'] = channels
config_dict['super_type'] = 'infer-width'
config_dict['estimated_FLOP'] = flop / 1e6
return flop / 1e6, config_dict
def get_arch_info(self):
string = "for width, there are {:} attention probabilities.".format(len(self.width_attentions))
discrepancy = []
with torch.no_grad():
for i, att in enumerate(self.width_attentions):
prob = nn.functional.softmax(att, dim=0)
prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist()
prob = ['{:.3f}'.format(x) for x in prob]
xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.width_attentions), ' '.join(prob))
logt = ['{:.3f}'.format(x) for x in att.cpu().tolist()]
xstring += ' || {:52s}'.format(' '.join(logt))
prob = sorted( [float(x) for x in prob] )
disc = prob[-1] - prob[-2]
xstring += ' || dis={:.2f} || select={:}/{:}'.format(disc, selc, len(prob))
discrepancy.append( disc )
string += '\n{:}'.format(xstring)
return string, discrepancy
def set_tau(self, tau_max, tau_min, epoch_ratio):
assert epoch_ratio >= 0 and epoch_ratio <= 1, 'invalid epoch-ratio : {:}'.format(epoch_ratio)
tau = tau_min + (tau_max-tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2
self.tau = tau
def get_message(self):
return self.message
def forward(self, inputs):
if self.search_mode == 'basic':
return self.basic_forward(inputs)
elif self.search_mode == 'search':
return self.search_forward(inputs)
else:
raise ValueError('invalid search_mode = {:}'.format(self.search_mode))
def search_forward(self, inputs):
flop_probs = nn.functional.softmax(self.width_attentions, dim=1)
selected_widths, selected_probs = select2withP(self.width_attentions, self.tau)
with torch.no_grad():
selected_widths = selected_widths.cpu()
x, last_channel_idx, expected_inC, flops = inputs, 0, 3, []
for i, layer in enumerate(self.layers):
selected_w_index = selected_widths[last_channel_idx: last_channel_idx+layer.num_conv]
selected_w_probs = selected_probs[last_channel_idx: last_channel_idx+layer.num_conv]
layer_prob = flop_probs[last_channel_idx: last_channel_idx+layer.num_conv]
x, expected_inC, expected_flop = layer( (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) )
last_channel_idx += layer.num_conv
flops.append( expected_flop )
flops.append( expected_inC * (self.classifier.out_features*1.0/1e6) )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = linear_forward(features, self.classifier)
return logits, torch.stack( [sum(flops)] )
def basic_forward(self, inputs):
if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1))
x = inputs
for i, layer in enumerate(self.layers):
x = layer( x )
features = self.avgpool(x)
features = features.view(features.size(0), -1)
logits = self.classifier(features)
return features, logits

View File

@ -4,4 +4,5 @@
from .SearchCifarResNet_width import SearchWidthCifarResNet from .SearchCifarResNet_width import SearchWidthCifarResNet
from .SearchCifarResNet_depth import SearchDepthCifarResNet from .SearchCifarResNet_depth import SearchDepthCifarResNet
from .SearchCifarResNet import SearchShapeCifarResNet from .SearchCifarResNet import SearchShapeCifarResNet
from .SearchSimResNet_width import SearchWidthSimResNet
from .SearchImagenetResNet import SearchShapeImagenetResNet from .SearchImagenetResNet import SearchShapeImagenetResNet

32
lib/tf_models/__init__.py Normal file
View File

@ -0,0 +1,32 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
from os import path as osp
__all__ = ['get_cell_based_tiny_net', 'get_search_spaces']
# the cell-based NAS models
def get_cell_based_tiny_net(config):
group_names = ['GDAS']
if config.name in group_names:
from .cell_searchs import nas_super_nets
from .cell_operations import SearchSpaceNames
if isinstance(config.space, str): search_space = SearchSpaceNames[config.space]
else: search_space = config.space
return nas_super_nets[config.name](
config.C, config.N, config.max_nodes,
config.num_classes, search_space, config.affine)
else:
raise ValueError('invalid network name : {:}'.format(config.name))
# obtain the search space, i.e., a dict mapping the operation name into a python-function for this op
def get_search_spaces(xtype, name):
if xtype == 'cell':
from .cell_operations import SearchSpaceNames
assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys())
return SearchSpaceNames[name]
else:
raise ValueError('invalid search-space type is {:}'.format(xtype))

View File

@ -0,0 +1,120 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import tensorflow as tf
__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']
OPS = {
'none' : lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride),
'avg_pool_3x3': lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg', affine),
'nor_conv_1x1': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 1, stride, affine),
'nor_conv_3x3': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 3, stride, affine),
'nor_conv_5x5': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 5, stride, affine),
'skip_connect': lambda C_in, C_out, stride, affine: Identity(C_in, C_out, stride)
}
NAS_BENCH_102 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
SearchSpaceNames = {
'nas-bench-102': NAS_BENCH_102,
}
class POOLING(tf.keras.layers.Layer):
def __init__(self, C_in, C_out, stride, mode, affine):
super(POOLING, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, affine)
if mode == 'avg' : self.op = tf.keras.layers.AvgPool2D((3,3), strides=stride, padding='same')
elif mode == 'max': self.op = tf.keras.layers.MaxPool2D((3,3), strides=stride, padding='same')
else : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
def call(self, inputs, training):
if self.preprocess: x = self.preprocess(inputs)
else : x = inputs
return self.op(x)
class Identity(tf.keras.layers.Layer):
def __init__(self, C_in, C_out, stride):
super(Identity, self).__init__()
if C_in != C_out or stride != 1:
self.layer = tf.keras.layers.Conv2D(C_out, 3, stride, padding='same', use_bias=False)
else:
self.layer = None
def call(self, inputs, training):
x = inputs
if self.layer is not None:
x = self.layer(x)
return x
class Zero(tf.keras.layers.Layer):
def __init__(self, C_in, C_out, stride):
super(Zero, self).__init__()
if C_in != C_out:
self.layer = tf.keras.layers.Conv2D(C_out, 1, stride, padding='same', use_bias=False)
elif stride != 1:
self.layer = tf.keras.layers.AvgPool2D((stride,stride), None, padding="same")
else:
self.layer = None
def call(self, inputs, training):
x = tf.zeros_like(inputs)
if self.layer is not None:
x = self.layer(x)
return x
class ReLUConvBN(tf.keras.layers.Layer):
def __init__(self, C_in, C_out, kernel_size, strides, affine):
super(ReLUConvBN, self).__init__()
self.C_in = C_in
self.relu = tf.keras.activations.relu
self.conv = tf.keras.layers.Conv2D(C_out, kernel_size, strides, padding='same', use_bias=False)
self.bn = tf.keras.layers.BatchNormalization(center=affine, scale=affine)
def call(self, inputs, training):
x = self.relu(inputs)
x = self.conv(x)
x = self.bn(x, training)
return x
class ResNetBasicblock(tf.keras.layers.Layer):
def __init__(self, inplanes, planes, stride, affine=True):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, affine)
self.conv_b = ReLUConvBN( planes, planes, 3, 1, affine)
if stride == 2:
self.downsample = tf.keras.Sequential([
tf.keras.layers.AvgPool2D((stride,stride), None, padding="same"),
tf.keras.layers.Conv2D(planes, 1, 1, padding='same', use_bias=False)])
elif inplanes != planes:
self.downsample = ReLUConvBN(inplanes, planes, 1, stride, affine)
else:
self.downsample = None
self.addition = tf.keras.layers.Add()
self.in_dim = inplanes
self.out_dim = planes
self.stride = stride
self.num_conv = 2
def call(self, inputs, training):
basicblock = self.conv_a(inputs, training)
basicblock = self.conv_b(basicblock, training)
if self.downsample is not None:
residual = self.downsample(inputs)
else:
residual = inputs
return self.addition([residual, basicblock])

View File

@ -0,0 +1,6 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .search_model_gdas import TinyNetworkGDAS
nas_super_nets = {'GDAS': TinyNetworkGDAS}

View File

@ -0,0 +1,50 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import math, random
import tensorflow as tf
from copy import deepcopy
from ..cell_operations import OPS
class SearchCell(tf.keras.layers.Layer):
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False):
super(SearchCell, self).__init__()
self.op_names = deepcopy(op_names)
self.max_nodes = max_nodes
self.in_dim = C_in
self.out_dim = C_out
self.edge_keys = []
for i in range(1, max_nodes):
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
if j == 0:
xlists = [OPS[op_name](C_in , C_out, stride, affine) for op_name in op_names]
else:
xlists = [OPS[op_name](C_in , C_out, 1, affine) for op_name in op_names]
for k, op in enumerate(xlists):
setattr(self, '{:}.{:}'.format(node_str, k), op)
self.edge_keys.append( node_str )
self.edge_keys = sorted(self.edge_keys)
self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
self.num_edges = len(self.edge_keys)
def call(self, inputs, weightss, training):
w_lst = tf.split(weightss, self.num_edges, 0)
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
edge_idx = self.edge2index[node_str]
op_outps = []
for k, op_name in enumerate(self.op_names):
op = getattr(self, '{:}.{:}'.format(node_str, k))
op_outps.append( op(nodes[j], training) )
stack_op_outs = tf.stack(op_outps, axis=-1)
weighted_sums = tf.math.multiply(stack_op_outs, w_lst[edge_idx])
inter_nodes.append( tf.math.reduce_sum(weighted_sums, axis=-1) )
nodes.append( tf.math.add_n(inter_nodes) )
return nodes[-1]

View File

@ -0,0 +1,99 @@
###########################################################################
# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 #
###########################################################################
import tensorflow as tf
import numpy as np
from copy import deepcopy
from ..cell_operations import ResNetBasicblock
from .search_cells import SearchCell
def sample_gumbel(shape, eps=1e-20):
U = tf.random.uniform(shape, minval=0, maxval=1)
return -tf.math.log(-tf.math.log(U + eps) + eps)
def gumbel_softmax(logits, temperature):
gumbel_softmax_sample = logits + sample_gumbel(tf.shape(logits))
y = tf.nn.softmax(gumbel_softmax_sample / temperature)
return y
class TinyNetworkGDAS(tf.keras.Model):
def __init__(self, C, N, max_nodes, num_classes, search_space, affine):
super(TinyNetworkGDAS, self).__init__()
self._C = C
self._layerN = N
self.max_nodes = max_nodes
self.stem = tf.keras.Sequential([
tf.keras.layers.Conv2D(16, 3, 1, padding='same', use_bias=False),
tf.keras.layers.BatchNormalization()], name='stem')
layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev, num_edge, edge2index = C, None, None
for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
cell_prefix = 'cell-{:03d}'.format(index)
#with tf.name_scope(cell_prefix) as scope:
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine)
if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index
else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
C_prev = cell.out_dim
setattr(self, cell_prefix, cell)
self.num_layers = len(layer_reductions)
self.op_names = deepcopy( search_space )
self.edge2index = edge2index
self.num_edge = num_edge
self.lastact = tf.keras.Sequential([
tf.keras.layers.BatchNormalization(),
tf.keras.layers.ReLU(),
tf.keras.layers.GlobalAvgPool2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(num_classes, activation='softmax')], name='lastact')
#self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) )
arch_init = tf.random_normal_initializer(mean=0, stddev=0.001)
self.arch_parameters = tf.Variable(initial_value=arch_init(shape=(num_edge, len(search_space)), dtype='float32'), trainable=True, name='arch-encoding')
def get_alphas(self):
xlist = self.trainable_variables
return [x for x in xlist if 'arch-encoding' in x.name]
def get_weights(self):
xlist = self.trainable_variables
return [x for x in xlist if 'arch-encoding' not in x.name]
def get_np_alphas(self):
arch_nps = self.arch_parameters.numpy()
arch_ops = np.exp(arch_nps) / np.sum(np.exp(arch_nps), axis=-1, keepdims=True)
return arch_ops
def genotype(self):
genotypes, arch_nps = [], self.arch_parameters.numpy()
for i in range(1, self.max_nodes):
xlist = []
for j in range(i):
node_str = '{:}<-{:}'.format(i, j)
weights = arch_nps[ self.edge2index[node_str] ]
op_name = self.op_names[ weights.argmax().item() ]
xlist.append((op_name, j))
genotypes.append( tuple(xlist) )
return genotypes
#
def call(self, inputs, tau, training):
weightss = tf.cond(tau < 0, lambda: tf.nn.softmax(self.arch_parameters, axis=1),
lambda: gumbel_softmax(tf.math.log_softmax(self.arch_parameters, axis=1), tau))
feature = self.stem(inputs, training)
for idx in range(self.num_layers):
cell = getattr(self, 'cell-{:03d}'.format(idx))
if isinstance(cell, SearchCell):
feature = cell.call(feature, weightss, training)
else:
feature = cell(feature, training)
logits = self.lastact(feature, training)
return logits

View File

@ -0,0 +1 @@
from .weight_decay_optimizers import AdamW, SGDW

View File

@ -0,0 +1,422 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class to make optimizers weight decay ready."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class DecoupledWeightDecayExtension(object):
"""This class allows to extend optimizers with decoupled weight decay.
It implements the decoupled weight decay described by Loshchilov & Hutter
(https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is
decoupled from the optimization steps w.r.t. to the loss function.
For SGD variants, this simplifies hyperparameter search since it decouples
the settings of weight decay and learning rate.
For adaptive gradient algorithms, it regularizes variables with large
gradients more than L2 regularization would, which was shown to yield
better training loss and generalization error in the paper above.
This class alone is not an optimizer but rather extends existing
optimizers with decoupled weight decay. We explicitly define the two
examples used in the above paper (SGDW and AdamW), but in general this
can extend any OptimizerX by using
`extend_with_decoupled_weight_decay(
OptimizerX, weight_decay=weight_decay)`.
In order for it to work, it must be the first class the Optimizer with
weight decay inherits from, e.g.
```python
class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
def __init__(self, weight_decay, *args, **kwargs):
super(AdamW, self).__init__(weight_decay, *args, **kwargs).
```
Note: this extension decays weights BEFORE applying the update based
on the gradient, i.e. this extension only has the desired behaviour for
optimizers which do not depend on the value of'var' in the update step!
Note: when applying a decay to the learning rate, be sure to manually apply
the decay to the `weight_decay` as well. For example:
```python
step = tf.Variable(0, trainable=False)
schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
[10000, 15000], [1e-0, 1e-1, 1e-2])
# lr and wd can be a function or a tensor
lr = 1e-1 * schedule(step)
wd = lambda: 1e-4 * schedule(step)
# ...
optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
```
"""
def __init__(self, weight_decay, **kwargs):
"""Extension class that adds weight decay to an optimizer.
Args:
weight_decay: A `Tensor` or a floating point value, the factor by
which a variable is decayed in the update step.
**kwargs: Optional list or tuple or set of `Variable` objects to
decay.
"""
wd = kwargs.pop('weight_decay', weight_decay)
super(DecoupledWeightDecayExtension, self).__init__(**kwargs)
self._decay_var_list = None # is set in minimize or apply_gradients
self._set_hyper('weight_decay', wd)
def get_config(self):
config = super(DecoupledWeightDecayExtension, self).get_config()
config.update({
'weight_decay':
self._serialize_hyperparameter('weight_decay'),
})
return config
def minimize(self,
loss,
var_list,
grad_loss=None,
name=None,
decay_var_list=None):
"""Minimize `loss` by updating `var_list`.
This method simply computes gradient using `tf.GradientTape` and calls
`apply_gradients()`. If you want to process the gradient before
applying then call `tf.GradientTape` and `apply_gradients()` explicitly
instead of using this function.
Args:
loss: A callable taking no arguments which returns the value to
minimize.
var_list: list or tuple of `Variable` objects to update to
minimize `loss`, or a callable returning the list or tuple of
`Variable` objects. Use callable when the variable list would
otherwise be incomplete before `minimize` since the variables
are created at the first time `loss` is called.
grad_loss: Optional. A `Tensor` holding the gradient computed for
`loss`.
decay_var_list: Optional list of variables to be decayed. Defaults
to all variables in var_list.
name: Optional name for the returned operation.
Returns:
An Operation that updates the variables in `var_list`. If
`global_step` was not `None`, that operation also increments
`global_step`.
Raises:
ValueError: If some of the variables are not `Variable` objects.
"""
self._decay_var_list = set(decay_var_list) if decay_var_list else False
return super(DecoupledWeightDecayExtension, self).minimize(
loss, var_list=var_list, grad_loss=grad_loss, name=name)
def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None):
"""Apply gradients to variables.
This is the second part of `minimize()`. It returns an `Operation` that
applies gradients.
Args:
grads_and_vars: List of (gradient, variable) pairs.
name: Optional name for the returned operation. Default to the
name passed to the `Optimizer` constructor.
decay_var_list: Optional list of variables to be decayed. Defaults
to all variables in var_list.
Returns:
An `Operation` that applies the specified gradients. If
`global_step` was not None, that operation also increments
`global_step`.
Raises:
TypeError: If `grads_and_vars` is malformed.
ValueError: If none of the variables have gradients.
"""
self._decay_var_list = set(decay_var_list) if decay_var_list else False
return super(DecoupledWeightDecayExtension, self).apply_gradients(
grads_and_vars, name=name)
def _decay_weights_op(self, var):
if not self._decay_var_list or var in self._decay_var_list:
return var.assign_sub(
self._get_hyper('weight_decay', var.dtype) * var,
self._use_locking)
return tf.no_op()
def _decay_weights_sparse_op(self, var, indices):
if not self._decay_var_list or var in self._decay_var_list:
update = (-self._get_hyper('weight_decay', var.dtype) * tf.gather(
var, indices))
return self._resource_scatter_add(var, indices, update)
return tf.no_op()
# Here, we overwrite the apply functions that the base optimizer calls.
# super().apply_x resolves to the apply_x function of the BaseOptimizer.
def _resource_apply_dense(self, grad, var):
with tf.control_dependencies([self._decay_weights_op(var)]):
return super(DecoupledWeightDecayExtension,
self)._resource_apply_dense(grad, var)
def _resource_apply_sparse(self, grad, var, indices):
decay_op = self._decay_weights_sparse_op(var, indices)
with tf.control_dependencies([decay_op]):
return super(DecoupledWeightDecayExtension,
self)._resource_apply_sparse(grad, var, indices)
def extend_with_decoupled_weight_decay(base_optimizer):
"""Factory function returning an optimizer class with decoupled weight
decay.
Returns an optimizer class. An instance of the returned class computes the
update step of `base_optimizer` and additionally decays the weights.
E.g., the class returned by
`extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)` is
equivalent to `tfa.optimizers.AdamW`.
The API of the new optimizer class slightly differs from the API of the
base optimizer:
- The first argument to the constructor is the weight decay rate.
- `minimize` and `apply_gradients` accept the optional keyword argument
`decay_var_list`, which specifies the variables that should be decayed.
If `None`, all variables that are optimized are decayed.
Usage example:
```python
# MyAdamW is a new class
MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)
# Create a MyAdamW object
optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
# update var1, var2 but only decay var1
optimizer.minimize(loss, var_list=[var1, var2], decay_variables=[var1])
Note: this extension decays weights BEFORE applying the update based
on the gradient, i.e. this extension only has the desired behaviour for
optimizers which do not depend on the value of 'var' in the update step!
Note: when applying a decay to the learning rate, be sure to manually apply
the decay to the `weight_decay` as well. For example:
```python
step = tf.Variable(0, trainable=False)
schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
[10000, 15000], [1e-0, 1e-1, 1e-2])
# lr and wd can be a function or a tensor
lr = 1e-1 * schedule(step)
wd = lambda: 1e-4 * schedule(step)
# ...
optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
```
Note: you might want to register your own custom optimizer using
`tf.keras.utils.get_custom_objects()`.
Args:
base_optimizer: An optimizer class that inherits from
tf.optimizers.Optimizer.
Returns:
A new optimizer class that inherits from DecoupledWeightDecayExtension
and base_optimizer.
"""
class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
base_optimizer):
"""Base_optimizer with decoupled weight decay.
This class computes the update step of `base_optimizer` and
additionally decays the variable with the weight decay being
decoupled from the optimization steps w.r.t. to the loss
function, as described by Loshchilov & Hutter
(https://arxiv.org/pdf/1711.05101.pdf). For SGD variants, this
simplifies hyperparameter search since it decouples the settings
of weight decay and learning rate. For adaptive gradient
algorithms, it regularizes variables with large gradients more
than L2 regularization would, which was shown to yield better
training loss and generalization error in the paper above.
"""
def __init__(self, weight_decay, *args, **kwargs):
# super delegation is necessary here
super(OptimizerWithDecoupledWeightDecay, self).__init__(
weight_decay, *args, **kwargs)
return OptimizerWithDecoupledWeightDecay
class SGDW(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD):
"""Optimizer that implements the Momentum algorithm with weight_decay.
This is an implementation of the SGDW optimizer described in "Decoupled
Weight Decay Regularization" by Loshchilov & Hutter
(https://arxiv.org/abs/1711.05101)
([pdf])(https://arxiv.org/pdf/1711.05101.pdf).
It computes the update step of `tf.keras.optimizers.SGD` and additionally
decays the variable. Note that this is different from adding
L2 regularization on the variables to the loss. Decoupling the weight decay
from other hyperparameters (in particular the learning rate) simplifies
hyperparameter search.
For further information see the documentation of the SGD Optimizer.
This optimizer can also be instantiated as
```python
extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD,
weight_decay=weight_decay)
```
Note: when applying a decay to the learning rate, be sure to manually apply
the decay to the `weight_decay` as well. For example:
```python
step = tf.Variable(0, trainable=False)
schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
[10000, 15000], [1e-0, 1e-1, 1e-2])
# lr and wd can be a function or a tensor
lr = 1e-1 * schedule(step)
wd = lambda: 1e-4 * schedule(step)
# ...
optimizer = tfa.optimizers.SGDW(
learning_rate=lr, weight_decay=wd, momentum=0.9)
```
"""
def __init__(self,
weight_decay,
learning_rate=0.001,
momentum=0.0,
nesterov=False,
name='SGDW',
**kwargs):
"""Construct a new SGDW optimizer.
For further information see the documentation of the SGD Optimizer.
Args:
learning_rate: float hyperparameter >= 0. Learning rate.
momentum: float hyperparameter >= 0 that accelerates SGD in the
relevant direction and dampens oscillations.
nesterov: boolean. Whether to apply Nesterov momentum.
name: Optional name prefix for the operations created when applying
gradients. Defaults to 'SGD'.
**kwargs: keyword arguments. Allowed to be {`clipnorm`,
`clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
norm; `clipvalue` is clip gradients by value, `decay` is
included for backward compatibility to allow time inverse decay
of learning rate. `lr` is included for backward compatibility,
recommended to use `learning_rate` instead.
"""
super(SGDW, self).__init__(
weight_decay,
learning_rate=learning_rate,
momentum=momentum,
nesterov=nesterov,
name=name,
**kwargs)
class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam):
"""Optimizer that implements the Adam algorithm with weight decay.
This is an implementation of the AdamW optimizer described in "Decoupled
Weight Decay Regularization" by Loshchilov & Hutter
(https://arxiv.org/abs/1711.05101)
([pdf])(https://arxiv.org/pdf/1711.05101.pdf).
It computes the update step of `tf.keras.optimizers.Adam` and additionally
decays the variable. Note that this is different from adding L2
regularization on the variables to the loss: it regularizes variables with
large gradients more than L2 regularization would, which was shown to yield
better training loss and generalization error in the paper above.
For further information see the documentation of the Adam Optimizer.
This optimizer can also be instantiated as
```python
extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam,
weight_decay=weight_decay)
```
Note: when applying a decay to the learning rate, be sure to manually apply
the decay to the `weight_decay` as well. For example:
```python
step = tf.Variable(0, trainable=False)
schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
[10000, 15000], [1e-0, 1e-1, 1e-2])
# lr and wd can be a function or a tensor
lr = 1e-1 * schedule(step)
wd = lambda: 1e-4 * schedule(step)
# ...
optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
```
"""
def __init__(self,
weight_decay,
learning_rate=0.001,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-07,
amsgrad=False,
name="AdamW",
**kwargs):
"""Construct a new AdamW optimizer.
For further information see the documentation of the Adam Optimizer.
Args:
weight_decay: A Tensor or a floating point value. The weight decay.
learning_rate: A Tensor or a floating point value. The learning
rate.
beta_1: A float value or a constant float tensor. The exponential
decay rate for the 1st moment estimates.
beta_2: A float value or a constant float tensor. The exponential
decay rate for the 2nd moment estimates.
epsilon: A small constant for numerical stability. This epsilon is
"epsilon hat" in the Kingma and Ba paper (in the formula just
before Section 2.1), not the epsilon in Algorithm 1 of the
paper.
amsgrad: boolean. Whether to apply AMSGrad variant of this
algorithm from the paper "On the Convergence of Adam and
beyond".
name: Optional name for the operations created when applying
gradients. Defaults to "AdamW".
**kwargs: keyword arguments. Allowed to be {`clipnorm`,
`clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
norm; `clipvalue` is clip gradients by value, `decay` is
included for backward compatibility to allow time inverse decay
of learning rate. `lr` is included for backward compatibility,
recommended to use `learning_rate` instead.
"""
super(AdamW, self).__init__(
weight_decay,
learning_rate=learning_rate,
beta_1=beta_1,
beta_2=beta_2,
epsilon=epsilon,
amsgrad=amsgrad,
name=name,
**kwargs)