diff --git a/configs/archs/CIFAR-SIM05.config b/configs/archs/CIFAR-SIM05.config new file mode 100644 index 0000000..c0f6f7b --- /dev/null +++ b/configs/archs/CIFAR-SIM05.config @@ -0,0 +1,7 @@ +{ + "dataset" : ["str", "cifar"], + "arch" : ["str", "simres"], + "depth" : ["int", 5], + "super_type": ["str" , "basic"], + "zero_init_residual" : ["bool", "0"] +} diff --git a/exps-tf/GDAS.py b/exps-tf/GDAS.py new file mode 100644 index 0000000..7ca3a45 --- /dev/null +++ b/exps-tf/GDAS.py @@ -0,0 +1,144 @@ +# CUDA_VISIBLE_DEVICES=0 python exps-tf/GDAS.py +import os, sys, time, random, argparse +import tensorflow as tf +from pathlib import Path + +lib_dir = (Path(__file__).parent / '..' / 'lib').resolve() +if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) + +# self-lib +from tf_models import get_cell_based_tiny_net +from tf_optimizers import SGDW, AdamW +from config_utils import dict2config +from log_utils import time_string +from models import CellStructure + + +def pre_process(image_a, label_a, image_b, label_b): + def standard_func(image): + x = tf.pad(image, [[4, 4], [4, 4], [0, 0]]) + x = tf.image.random_crop(x, [32, 32, 3]) + x = tf.image.random_flip_left_right(x) + return x + return standard_func(image_a), label_a, standard_func(image_b), label_b + + +def main(xargs): + cifar10 = tf.keras.datasets.cifar10 + + (x_train, y_train), (x_test, y_test) = cifar10.load_data() + x_train, x_test = x_train / 255.0, x_test / 255.0 + x_train, x_test = x_train.astype('float32'), x_test.astype('float32') + + # Add a channels dimension + all_indexes = list(range(x_train.shape[0])) + random.shuffle(all_indexes) + s_train_idxs, s_valid_idxs = all_indexes[::2], all_indexes[1::2] + search_train_x, search_train_y = x_train[s_train_idxs], y_train[s_train_idxs] + search_valid_x, search_valid_y = x_train[s_valid_idxs], y_train[s_valid_idxs] + #x_train, x_test = x_train[..., tf.newaxis], x_test[..., tf.newaxis] + + # Use tf.data + #train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(64) + search_ds = tf.data.Dataset.from_tensor_slices((search_train_x, search_train_y, search_valid_x, search_valid_y)) + search_ds = search_ds.map(pre_process).shuffle(1000).batch(64) + + test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32) + + # Create an instance of the model + config = dict2config({'name': 'GDAS', + 'C' : xargs.channel, 'N': xargs.num_cells, 'max_nodes': xargs.max_nodes, + 'num_classes': 10, 'space': 'nas-bench-102', 'affine': True}, None) + model = get_cell_based_tiny_net(config) + #import pdb; pdb.set_trace() + #model.build(((64, 32, 32, 3), (1,))) + #for x in model.trainable_variables: + # print('{:30s} : {:}'.format(x.name, x.shape)) + # Choose optimizer + loss_object = tf.keras.losses.SparseCategoricalCrossentropy() + w_optimizer = SGDW(learning_rate=xargs.w_lr, weight_decay=xargs.w_weight_decay, momentum=xargs.w_momentum, nesterov=True) + a_optimizer = AdamW(learning_rate=xargs.arch_learning_rate, weight_decay=xargs.arch_weight_decay, beta_1=0.5, beta_2=0.999, epsilon=1e-07) + #w_optimizer = tf.keras.optimizers.SGD(learning_rate=0.025, momentum=0.9, nesterov=True) + #a_optimizer = tf.keras.optimizers.AdamW(learning_rate=xargs.arch_learning_rate, beta_1=0.5, beta_2=0.999, epsilon=1e-07) + #### + # metrics + train_loss = tf.keras.metrics.Mean(name='train_loss') + train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy') + valid_loss = tf.keras.metrics.Mean(name='valid_loss') + valid_accuracy = 
tf.keras.metrics.SparseCategoricalAccuracy(name='valid_accuracy') + test_loss = tf.keras.metrics.Mean(name='test_loss') + test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy') + + @tf.function + def search_step(train_images, train_labels, valid_images, valid_labels, tf_tau): + # optimize weights + with tf.GradientTape() as tape: + predictions = model(train_images, tf_tau, True) + w_loss = loss_object(train_labels, predictions) + net_w_param = model.get_weights() + gradients = tape.gradient(w_loss, net_w_param) + w_optimizer.apply_gradients(zip(gradients, net_w_param)) + train_loss(w_loss) + train_accuracy(train_labels, predictions) + # optimize alphas + with tf.GradientTape() as tape: + predictions = model(valid_images, tf_tau, True) + a_loss = loss_object(valid_labels, predictions) + net_a_param = model.get_alphas() + gradients = tape.gradient(a_loss, net_a_param) + a_optimizer.apply_gradients(zip(gradients, net_a_param)) + valid_loss(a_loss) + valid_accuracy(valid_labels, predictions) + + # TEST + @tf.function + def test_step(images, labels): + predictions = model(images, tf.cast(-1.0, dtype=tf.float32), False) # tau < 0 falls back to a plain softmax over the architecture weights + t_loss = loss_object(labels, predictions) + + test_loss(t_loss) + test_accuracy(labels, predictions) + + print('{:} start searching with {:} epochs ({:} batches per epoch).'.format(time_string(), xargs.epochs, tf.data.experimental.cardinality(search_ds).numpy())) + + for epoch in range(xargs.epochs): + # Reset the metrics at the start of the next epoch + train_loss.reset_states() ; train_accuracy.reset_states() ; valid_loss.reset_states() ; valid_accuracy.reset_states() + test_loss.reset_states() ; test_accuracy.reset_states() + cur_tau = xargs.tau_max - (xargs.tau_max-xargs.tau_min) * epoch / (xargs.epochs-1) + tf_tau = tf.cast(cur_tau, dtype=tf.float32, name='tau') + + for trn_imgs, trn_labels, val_imgs, val_labels in search_ds: + search_step(trn_imgs, trn_labels, val_imgs, val_labels, tf_tau) + genotype = model.genotype() + genotype = CellStructure(genotype) + + #for test_images, test_labels in test_ds: + # test_step(test_images, test_labels) + + template = '{:} Epoch {:03d}/{:03d}, Train-Loss: {:.3f}, Train-Accuracy: {:.2f}%, Valid-Loss: {:.3f}, Valid-Accuracy: {:.2f}% | tau={:.3f}' + print(template.format(time_string(), epoch+1, xargs.epochs, + train_loss.result(), + train_accuracy.result()*100, + valid_loss.result(), + valid_accuracy.result()*100, + cur_tau)) + print('{:} genotype : {:}\n{:}\n'.format(time_string(), genotype, model.get_np_alphas())) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='NAS-Bench-102', formatter_class=argparse.ArgumentDefaultsHelpFormatter) + # training details + parser.add_argument('--epochs' , type=int , default= 250 , help='') + parser.add_argument('--tau_max' , type=float, default= 10 , help='') + parser.add_argument('--tau_min' , type=float, default= 0.1 , help='') + parser.add_argument('--w_lr' , type=float, default= 0.025, help='') + parser.add_argument('--w_weight_decay' , type=float, default=0.0005, help='') + parser.add_argument('--w_momentum' , type=float, default= 0.9 , help='') + parser.add_argument('--arch_learning_rate', type=float, default=0.0003, help='') + parser.add_argument('--arch_weight_decay' , type=float, default=0.001, help='') + # macro structure + parser.add_argument('--channel' , type=int , default=16, help='') + parser.add_argument('--num_cells' , type=int , default= 5, help='') + parser.add_argument('--max_nodes' , type=int , default= 4, help='') + args = parser.parse_args() + main( args ) diff --git a/exps/vis/test.py index
17ccb95..99cd31a 100644 --- a/exps/vis/test.py +++ b/exps/vis/test.py @@ -6,7 +6,6 @@ import numpy as np from collections import OrderedDict lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve() if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) -from graphviz import Digraph def test_nas_api(): @@ -29,6 +28,7 @@ OPS = ['skip-connect', 'conv-1x1', 'conv-3x3', 'pool-3x3'] COLORS = ['chartreuse' , 'cyan' , 'navyblue', 'chocolate1'] def plot(filename): + from graphviz import Digraph g = Digraph( format='png', edge_attr=dict(fontsize='20', fontname="times"), @@ -53,6 +53,26 @@ def plot(filename): g.render(filename, cleanup=True, view=False) +def test_auto_grad(): + class Net(torch.nn.Module): + def __init__(self, iS): + super(Net, self).__init__() + self.layer = torch.nn.Linear(iS, 1) + def forward(self, inputs): + outputs = self.layer(inputs) + outputs = torch.exp(outputs) + return outputs.mean() + net = Net(10) + inputs = torch.rand(256, 10) + loss = net(inputs) + first_order_grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True, create_graph=True) + first_order_grads = torch.cat([x.view(-1) for x in first_order_grads]) + second_order_grads = [] + for grads in first_order_grads: + s_grads = torch.autograd.grad(grads, net.parameters(), retain_graph=True) # retain the graph so every element of the first-order gradient can be differentiated + second_order_grads.append( s_grads ) + if __name__ == '__main__': - test_nas_api() - for i in range(200): plot('{:04d}'.format(i)) + #test_nas_api() + #for i in range(200): plot('{:04d}'.format(i)) + test_auto_grad() diff --git a/lib/log_utils/__init__.py b/lib/log_utils/__init__.py index 0c8858a..c491293 100644 --- a/lib/log_utils/__init__.py +++ b/lib/log_utils/__init__.py @@ -1,7 +1,8 @@ ################################################## # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # ################################################## -from .logger import Logger -from .print_logger import PrintLogger +# no module in this package relies on pytorch or tensorflow +# I tried to list all dependencies here: os, sys, time, numpy, (possibly) matplotlib +from .logger import Logger, PrintLogger from .meter import AverageMeter -from .time_utils import time_for_file, time_string, time_string_short, time_print, convert_size2str, convert_secs2time +from .time_utils import time_for_file, time_string, time_string_short, time_print, convert_secs2time diff --git a/lib/log_utils/logger.py b/lib/log_utils/logger.py index 02c368c..e60c78f 100644 --- a/lib/log_utils/logger.py +++ b/lib/log_utils/logger.py @@ -1,9 +1,6 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree.
-# +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +################################################## from pathlib import Path import importlib, warnings import os, sys, time, numpy as np @@ -16,6 +13,19 @@ if importlib.util.find_spec('tensorflow'): import tensorflow as tf +class PrintLogger(object): + + def __init__(self): + """Create a summary writer logging to log_dir.""" + self.name = 'PrintLogger' + + def log(self, string): + print (string) + + def close(self): + print ('-'*30 + ' close printer ' + '-'*30) + + class Logger(object): def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False): diff --git a/lib/log_utils/meter.py b/lib/log_utils/meter.py index 3138fec..cbb9dd1 100644 --- a/lib/log_utils/meter.py +++ b/lib/log_utils/meter.py @@ -1,4 +1,3 @@ -import time, sys import numpy as np diff --git a/lib/log_utils/print_logger.py b/lib/log_utils/print_logger.py deleted file mode 100644 index 5dc5b14..0000000 --- a/lib/log_utils/print_logger.py +++ /dev/null @@ -1,14 +0,0 @@ -import os, sys, time - - -class PrintLogger(object): - - def __init__(self): - """Create a summary writer logging to log_dir.""" - self.name = 'PrintLogger' - - def log(self, string): - print (string) - - def close(self): - print ('-'*30 + ' close printer ' + '-'*30) diff --git a/lib/log_utils/time_utils.py b/lib/log_utils/time_utils.py index 7886fcc..e38461f 100644 --- a/lib/log_utils/time_utils.py +++ b/lib/log_utils/time_utils.py @@ -1,37 +1,27 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +################################################## import time, sys import numpy as np def time_for_file(): ISOTIMEFORMAT='%d-%h-at-%H-%M-%S' - return '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) + return '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) def time_string(): ISOTIMEFORMAT='%Y-%m-%d %X' - string = '[{}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) + string = '[{:}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) return string def time_string_short(): ISOTIMEFORMAT='%Y%m%d' - string = '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) + string = '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) return string def time_print(string, is_print=True): if (is_print): print('{} : {}'.format(time_string(), string)) -def convert_size2str(torch_size): - dims = len(torch_size) - string = '[' - for idim in range(dims): - string = string + ' {}'.format(torch_size[idim]) - return string + ']' - def convert_secs2time(epoch_time, return_str=False): need_hour = int(epoch_time / 3600) need_mins = int((epoch_time - 3600*need_hour) / 60) diff --git a/lib/models/__init__.py b/lib/models/__init__.py index 0a50df7..f19dbae 100644 --- a/lib/models/__init__.py +++ b/lib/models/__init__.py @@ -1,7 +1,6 @@ ################################################## # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # ################################################## -import torch from os import path as osp __all__ = ['change_key', 'get_cell_based_tiny_net', 'get_search_spaces', 'get_cifar_models', 'get_imagenet_models', \ @@ -126,6 +125,11 @@ def obtain_search_model(config): elif 
config.search_mode == 'shape': return SearchShapeCifarResNet(config.module, config.depth, config.class_num) else: raise ValueError('invalid search mode : {:}'.format(config.search_mode)) + elif config.arch == 'simres': + from .shape_searchs import SearchWidthSimResNet + if config.search_mode == 'width': + return SearchWidthSimResNet(config.depth, config.class_num) + else: raise ValueError('invalid search mode : {:}'.format(config.search_mode)) else: raise ValueError('invalid arch : {:} for dataset [{:}]'.format(config.arch, config.dataset)) elif config.dataset == 'imagenet': @@ -140,6 +144,7 @@ def obtain_search_model(config): def load_net_from_checkpoint(checkpoint): + import torch assert osp.isfile(checkpoint), 'checkpoint {:} does not exist'.format(checkpoint) checkpoint = torch.load(checkpoint) model_config = dict2config(checkpoint['model-config'], None) diff --git a/lib/models/shape_searchs/SearchSimResNet_width.py b/lib/models/shape_searchs/SearchSimResNet_width.py new file mode 100644 index 0000000..dbd9cad --- /dev/null +++ b/lib/models/shape_searchs/SearchSimResNet_width.py @@ -0,0 +1,316 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +################################################## +import math, torch +import torch.nn as nn +from ..initialization import initialize_resnet +from ..SharedUtils import additive_func +from .SoftSelect import select2withP, ChannelWiseInter +from .SoftSelect import linear_forward +from .SoftSelect import get_width_choices as get_choices + + +def conv_forward(inputs, conv, choices): + iC = conv.in_channels + fill_size = list(inputs.size()) + fill_size[1] = iC - fill_size[1] + filled = torch.zeros(fill_size, device=inputs.device) + xinputs = torch.cat((inputs, filled), dim=1) + outputs = conv(xinputs) + selecteds = [outputs[:,:oC] for oC in choices] + return selecteds + + +class ConvBNReLU(nn.Module): + num_conv = 1 + def __init__(self, nIn, nOut, kernel, stride, padding, bias, has_avg, has_bn, has_relu): + super(ConvBNReLU, self).__init__() + self.InShape = None + self.OutShape = None + self.choices = get_choices(nOut) + self.register_buffer('choices_tensor', torch.Tensor( self.choices )) + + if has_avg : self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + else : self.avg = None + self.conv = nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, padding=padding, dilation=1, groups=1, bias=bias) + #if has_bn : self.bn = nn.BatchNorm2d(nOut) + #else : self.bn = None + self.has_bn = has_bn + self.BNs = nn.ModuleList() + for i, _out in enumerate(self.choices): + self.BNs.append(nn.BatchNorm2d(_out)) + if has_relu: self.relu = nn.ReLU(inplace=True) + else : self.relu = None + self.in_dim = nIn + self.out_dim = nOut + self.search_mode = 'basic' + + def get_flops(self, channels, check_range=True, divide=1): + iC, oC = channels + if check_range: assert iC <= self.conv.in_channels and oC <= self.conv.out_channels, '{:} vs {:} | {:} vs {:}'.format(iC, self.conv.in_channels, oC, self.conv.out_channels) + assert isinstance(self.InShape, tuple) and len(self.InShape) == 2, 'invalid in-shape : {:}'.format(self.InShape) + assert isinstance(self.OutShape, tuple) and len(self.OutShape) == 2, 'invalid out-shape : {:}'.format(self.OutShape) + #conv_per_position_flops = self.conv.kernel_size[0] * self.conv.kernel_size[1] * iC * oC / self.conv.groups + conv_per_position_flops = (self.conv.kernel_size[0] * self.conv.kernel_size[1] * 1.0 / self.conv.groups) + all_positions = self.OutShape[0] * 
self.OutShape[1] + flops = (conv_per_position_flops * all_positions / divide) * iC * oC + if self.conv.bias is not None: flops += all_positions / divide + return flops + + def get_range(self): + return [self.choices] + + def forward(self, inputs): + if self.search_mode == 'basic': + return self.basic_forward(inputs) + elif self.search_mode == 'search': + return self.search_forward(inputs) + else: + raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) + + def search_forward(self, tuple_inputs): + assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) ) + inputs, expected_inC, probability, index, prob = tuple_inputs + index, prob = torch.squeeze(index).tolist(), torch.squeeze(prob) + probability = torch.squeeze(probability) + assert len(index) == 2, 'invalid length : {:}'.format(index) + # compute expected flop + #coordinates = torch.arange(self.x_range[0], self.x_range[1]+1).type_as(probability) + expected_outC = (self.choices_tensor * probability).sum() + expected_flop = self.get_flops([expected_inC, expected_outC], False, 1e6) + if self.avg : out = self.avg( inputs ) + else : out = inputs + # convolutional layer + out_convs = conv_forward(out, self.conv, [self.choices[i] for i in index]) + out_bns = [self.BNs[idx](out_conv) for idx, out_conv in zip(index, out_convs)] + # merge + out_channel = max([x.size(1) for x in out_bns]) + outA = ChannelWiseInter(out_bns[0], out_channel) + outB = ChannelWiseInter(out_bns[1], out_channel) + out = outA * prob[0] + outB * prob[1] + #out = additive_func(out_bns[0]*prob[0], out_bns[1]*prob[1]) + + if self.relu: out = self.relu( out ) + else : out = out + return out, expected_outC, expected_flop + + def basic_forward(self, inputs): + if self.avg : out = self.avg( inputs ) + else : out = inputs + conv = self.conv( out ) + if self.has_bn:out= self.BNs[-1]( conv ) + else : out = conv + if self.relu: out = self.relu( out ) + else : out = out + if self.InShape is None: + self.InShape = (inputs.size(-2), inputs.size(-1)) + self.OutShape = (out.size(-2) , out.size(-1)) + return out + + +class SimBlock(nn.Module): + expansion = 1 + num_conv = 1 + def __init__(self, inplanes, planes, stride): + super(SimBlock, self).__init__() + assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) + self.conv = ConvBNReLU(inplanes, planes, 3, stride, 1, False, has_avg=False, has_bn=True, has_relu=True) + if stride == 2: + self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=True, has_bn=False, has_relu=False) + elif inplanes != planes: + self.downsample = ConvBNReLU(inplanes, planes, 1, 1, 0, False, has_avg=False,has_bn=True , has_relu=False) + else: + self.downsample = None + self.out_dim = planes + self.search_mode = 'basic' + + def get_range(self): + return self.conv.get_range() + + def get_flops(self, channels): + assert len(channels) == 2, 'invalid channels : {:}'.format(channels) + flop_A = self.conv.get_flops([channels[0], channels[1]]) + if hasattr(self.downsample, 'get_flops'): + flop_C = self.downsample.get_flops([channels[0], channels[-1]]) + else: + flop_C = 0 + if channels[0] != channels[-1] and self.downsample is None: # this short-cut will be added during the infer-train + flop_C = channels[0] * channels[-1] * self.conv.OutShape[0] * self.conv.OutShape[1] + return flop_A + flop_C + + def forward(self, inputs): + if self.search_mode == 'basic' : return self.basic_forward(inputs) + elif self.search_mode == 'search': return self.search_forward(inputs) 
+ else: raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) + + def search_forward(self, tuple_inputs): + assert isinstance(tuple_inputs, tuple) and len(tuple_inputs) == 5, 'invalid type input : {:}'.format( type(tuple_inputs) ) + inputs, expected_inC, probability, indexes, probs = tuple_inputs + assert indexes.size(0) == 1 and probs.size(0) == 1 and probability.size(0) == 1, 'invalid size : {:}, {:}, {:}'.format(indexes.size(), probs.size(), probability.size()) + out, expected_next_inC, expected_flop = self.conv( (inputs, expected_inC , probability[0], indexes[0], probs[0]) ) + if self.downsample is not None: + residual, _, expected_flop_c = self.downsample( (inputs, expected_inC , probability[-1], indexes[-1], probs[-1]) ) + else: + residual, expected_flop_c = inputs, 0 + out = additive_func(residual, out) + return out, expected_next_inC, sum([expected_flop, expected_flop_c]) + + def basic_forward(self, inputs): + basicblock = self.conv(inputs) + if self.downsample is not None: residual = self.downsample(inputs) + else : residual = inputs + out = additive_func(residual, basicblock) + return nn.functional.relu(out, inplace=True) + + + +class SearchWidthSimResNet(nn.Module): + + def __init__(self, depth, num_classes): + super(SearchWidthSimResNet, self).__init__() + + assert (depth - 2) % 3 == 0, 'depth should be one of 5, 8, 11, 14, ... instead of {:}'.format(depth) + layer_blocks = (depth - 2) // 3 + self.message = 'SearchWidthSimResNet : Depth : {:} , Layers for each block : {:}'.format(depth, layer_blocks) + self.num_classes = num_classes + self.channels = [16] + self.layers = nn.ModuleList( [ ConvBNReLU(3, 16, 3, 1, 1, False, has_avg=False, has_bn=True, has_relu=True) ] ) + self.InShape = None + for stage in range(3): + for iL in range(layer_blocks): + iC = self.channels[-1] + planes = 16 * (2**stage) + stride = 2 if stage > 0 and iL == 0 else 1 + module = SimBlock(iC, planes, stride) + self.channels.append( module.out_dim ) + self.layers.append ( module ) + self.message += "\nstage={:}, ilayer={:02d}/{:02d}, block={:03d}, iC={:3d}, oC={:3d}, stride={:}".format(stage, iL, layer_blocks, len(self.layers)-1, iC, module.out_dim, stride) + + self.avgpool = nn.AvgPool2d(8) + self.classifier = nn.Linear(module.out_dim, num_classes) + self.InShape = None + self.tau = -1 + self.search_mode = 'basic' + #assert sum(x.num_conv for x in self.layers) + 1 == depth, 'invalid depth check {:} vs {:}'.format(sum(x.num_conv for x in self.layers)+1, depth) + + # parameters for width + self.Ranges = [] + self.layer2indexRange = [] + for i, layer in enumerate(self.layers): + start_index = len(self.Ranges) + self.Ranges += layer.get_range() + self.layer2indexRange.append( (start_index, len(self.Ranges)) ) + assert len(self.Ranges) + 1 == depth, 'invalid depth check {:} vs {:}'.format(len(self.Ranges) + 1, depth) + + self.register_parameter('width_attentions', nn.Parameter(torch.Tensor(len(self.Ranges), get_choices(None)))) + nn.init.normal_(self.width_attentions, 0, 0.01) + self.apply(initialize_resnet) + + def arch_parameters(self): + return [self.width_attentions] + + def base_parameters(self): + return list(self.layers.parameters()) + list(self.avgpool.parameters()) + list(self.classifier.parameters()) + + def get_flop(self, mode, config_dict, extra_info): + if config_dict is not None: config_dict = config_dict.copy() + #weights = [F.softmax(x, dim=0) for x in self.width_attentions] + channels = [3] + for i, weight in enumerate(self.width_attentions): + if mode == 'genotype': + with 
torch.no_grad(): + probe = nn.functional.softmax(weight, dim=0) + C = self.Ranges[i][ torch.argmax(probe).item() ] + elif mode == 'max': + C = self.Ranges[i][-1] + elif mode == 'fix': + C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] ) + elif mode == 'random': + assert isinstance(extra_info, float), 'invalid extra_info : {:}'.format(extra_info) + with torch.no_grad(): + prob = nn.functional.softmax(weight, dim=0) + approximate_C = int( math.sqrt( extra_info ) * self.Ranges[i][-1] ) + for j in range(prob.size(0)): + prob[j] = 1 / (abs(j - (approximate_C-self.Ranges[i][j])) + 0.2) + C = self.Ranges[i][ torch.multinomial(prob, 1, False).item() ] + else: + raise ValueError('invalid mode : {:}'.format(mode)) + channels.append( C ) + flop = 0 + for i, layer in enumerate(self.layers): + s, e = self.layer2indexRange[i] + xchl = tuple( channels[s:e+1] ) + flop+= layer.get_flops(xchl) + # the last fc layer + flop += channels[-1] * self.classifier.out_features + if config_dict is None: + return flop / 1e6 + else: + config_dict['xchannels'] = channels + config_dict['super_type'] = 'infer-width' + config_dict['estimated_FLOP'] = flop / 1e6 + return flop / 1e6, config_dict + + def get_arch_info(self): + string = "for width, there are {:} attention probabilities.".format(len(self.width_attentions)) + discrepancy = [] + with torch.no_grad(): + for i, att in enumerate(self.width_attentions): + prob = nn.functional.softmax(att, dim=0) + prob = prob.cpu() ; selc = prob.argmax().item() ; prob = prob.tolist() + prob = ['{:.3f}'.format(x) for x in prob] + xstring = '{:03d}/{:03d}-th : {:}'.format(i, len(self.width_attentions), ' '.join(prob)) + logt = ['{:.3f}'.format(x) for x in att.cpu().tolist()] + xstring += ' || {:52s}'.format(' '.join(logt)) + prob = sorted( [float(x) for x in prob] ) + disc = prob[-1] - prob[-2] + xstring += ' || dis={:.2f} || select={:}/{:}'.format(disc, selc, len(prob)) + discrepancy.append( disc ) + string += '\n{:}'.format(xstring) + return string, discrepancy + + def set_tau(self, tau_max, tau_min, epoch_ratio): + assert epoch_ratio >= 0 and epoch_ratio <= 1, 'invalid epoch-ratio : {:}'.format(epoch_ratio) + tau = tau_min + (tau_max-tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2 + self.tau = tau + + def get_message(self): + return self.message + + def forward(self, inputs): + if self.search_mode == 'basic': + return self.basic_forward(inputs) + elif self.search_mode == 'search': + return self.search_forward(inputs) + else: + raise ValueError('invalid search_mode = {:}'.format(self.search_mode)) + + def search_forward(self, inputs): + flop_probs = nn.functional.softmax(self.width_attentions, dim=1) + selected_widths, selected_probs = select2withP(self.width_attentions, self.tau) + with torch.no_grad(): + selected_widths = selected_widths.cpu() + + x, last_channel_idx, expected_inC, flops = inputs, 0, 3, [] + for i, layer in enumerate(self.layers): + selected_w_index = selected_widths[last_channel_idx: last_channel_idx+layer.num_conv] + selected_w_probs = selected_probs[last_channel_idx: last_channel_idx+layer.num_conv] + layer_prob = flop_probs[last_channel_idx: last_channel_idx+layer.num_conv] + x, expected_inC, expected_flop = layer( (x, expected_inC, layer_prob, selected_w_index, selected_w_probs) ) + last_channel_idx += layer.num_conv + flops.append( expected_flop ) + flops.append( expected_inC * (self.classifier.out_features*1.0/1e6) ) + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = linear_forward(features, self.classifier) + 
return logits, torch.stack( [sum(flops)] ) + + def basic_forward(self, inputs): + if self.InShape is None: self.InShape = (inputs.size(-2), inputs.size(-1)) + x = inputs + for i, layer in enumerate(self.layers): + x = layer( x ) + features = self.avgpool(x) + features = features.view(features.size(0), -1) + logits = self.classifier(features) + return features, logits diff --git a/lib/models/shape_searchs/__init__.py b/lib/models/shape_searchs/__init__.py index 91a58f4..554f035 100644 --- a/lib/models/shape_searchs/__init__.py +++ b/lib/models/shape_searchs/__init__.py @@ -4,4 +4,5 @@ from .SearchCifarResNet_width import SearchWidthCifarResNet from .SearchCifarResNet_depth import SearchDepthCifarResNet from .SearchCifarResNet import SearchShapeCifarResNet +from .SearchSimResNet_width import SearchWidthSimResNet from .SearchImagenetResNet import SearchShapeImagenetResNet diff --git a/lib/tf_models/__init__.py b/lib/tf_models/__init__.py new file mode 100644 index 0000000..ac05da5 --- /dev/null +++ b/lib/tf_models/__init__.py @@ -0,0 +1,32 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +################################################## +import torch +from os import path as osp + +__all__ = ['get_cell_based_tiny_net', 'get_search_spaces'] + + +# the cell-based NAS models +def get_cell_based_tiny_net(config): + group_names = ['GDAS'] + if config.name in group_names: + from .cell_searchs import nas_super_nets + from .cell_operations import SearchSpaceNames + if isinstance(config.space, str): search_space = SearchSpaceNames[config.space] + else: search_space = config.space + return nas_super_nets[config.name]( + config.C, config.N, config.max_nodes, + config.num_classes, search_space, config.affine) + else: + raise ValueError('invalid network name : {:}'.format(config.name)) + + +# obtain the search space, i.e., a dict mapping the operation name into a python-function for this op +def get_search_spaces(xtype, name): + if xtype == 'cell': + from .cell_operations import SearchSpaceNames + assert name in SearchSpaceNames, 'invalid name [{:}] in {:}'.format(name, SearchSpaceNames.keys()) + return SearchSpaceNames[name] + else: + raise ValueError('invalid search-space type is {:}'.format(xtype)) diff --git a/lib/tf_models/cell_operations.py b/lib/tf_models/cell_operations.py new file mode 100644 index 0000000..a98e190 --- /dev/null +++ b/lib/tf_models/cell_operations.py @@ -0,0 +1,120 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +################################################## +import tensorflow as tf + +__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames'] + +OPS = { + 'none' : lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride), + 'avg_pool_3x3': lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg', affine), + 'nor_conv_1x1': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 1, stride, affine), + 'nor_conv_3x3': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 3, stride, affine), + 'nor_conv_5x5': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, 5, stride, affine), + 'skip_connect': lambda C_in, C_out, stride, affine: Identity(C_in, C_out, stride) +} + +NAS_BENCH_102 = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3'] + +SearchSpaceNames = { + 'nas-bench-102': NAS_BENCH_102, + } + + +class POOLING(tf.keras.layers.Layer): + + def __init__(self, C_in, C_out, stride, mode, affine): + 
super(POOLING, self).__init__() + if C_in == C_out: + self.preprocess = None + else: + self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, affine) + if mode == 'avg' : self.op = tf.keras.layers.AvgPool2D((3,3), strides=stride, padding='same') + elif mode == 'max': self.op = tf.keras.layers.MaxPool2D((3,3), strides=stride, padding='same') + else : raise ValueError('Invalid mode={:} in POOLING'.format(mode)) + + def call(self, inputs, training): + if self.preprocess: x = self.preprocess(inputs) + else : x = inputs + return self.op(x) + + +class Identity(tf.keras.layers.Layer): + def __init__(self, C_in, C_out, stride): + super(Identity, self).__init__() + if C_in != C_out or stride != 1: + self.layer = tf.keras.layers.Conv2D(C_out, 3, stride, padding='same', use_bias=False) + else: + self.layer = None + + def call(self, inputs, training): + x = inputs + if self.layer is not None: + x = self.layer(x) + return x + + + +class Zero(tf.keras.layers.Layer): + def __init__(self, C_in, C_out, stride): + super(Zero, self).__init__() + if C_in != C_out: + self.layer = tf.keras.layers.Conv2D(C_out, 1, stride, padding='same', use_bias=False) + elif stride != 1: + self.layer = tf.keras.layers.AvgPool2D((stride,stride), None, padding="same") + else: + self.layer = None + + def call(self, inputs, training): + x = tf.zeros_like(inputs) + if self.layer is not None: + x = self.layer(x) + return x + + +class ReLUConvBN(tf.keras.layers.Layer): + def __init__(self, C_in, C_out, kernel_size, strides, affine): + super(ReLUConvBN, self).__init__() + self.C_in = C_in + self.relu = tf.keras.activations.relu + self.conv = tf.keras.layers.Conv2D(C_out, kernel_size, strides, padding='same', use_bias=False) + self.bn = tf.keras.layers.BatchNormalization(center=affine, scale=affine) + + def call(self, inputs, training): + x = self.relu(inputs) + x = self.conv(x) + x = self.bn(x, training) + return x + + +class ResNetBasicblock(tf.keras.layers.Layer): + + def __init__(self, inplanes, planes, stride, affine=True): + super(ResNetBasicblock, self).__init__() + assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride) + self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, affine) + self.conv_b = ReLUConvBN( planes, planes, 3, 1, affine) + if stride == 2: + self.downsample = tf.keras.Sequential([ + tf.keras.layers.AvgPool2D((stride,stride), None, padding="same"), + tf.keras.layers.Conv2D(planes, 1, 1, padding='same', use_bias=False)]) + elif inplanes != planes: + self.downsample = ReLUConvBN(inplanes, planes, 1, stride, affine) + else: + self.downsample = None + self.addition = tf.keras.layers.Add() + self.in_dim = inplanes + self.out_dim = planes + self.stride = stride + self.num_conv = 2 + + def call(self, inputs, training): + + basicblock = self.conv_a(inputs, training) + basicblock = self.conv_b(basicblock, training) + + if self.downsample is not None: + residual = self.downsample(inputs) + else: + residual = inputs + return self.addition([residual, basicblock]) diff --git a/lib/tf_models/cell_searchs/__init__.py b/lib/tf_models/cell_searchs/__init__.py new file mode 100644 index 0000000..479cb03 --- /dev/null +++ b/lib/tf_models/cell_searchs/__init__.py @@ -0,0 +1,6 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +################################################## +from .search_model_gdas import TinyNetworkGDAS + +nas_super_nets = {'GDAS': TinyNetworkGDAS} diff --git a/lib/tf_models/cell_searchs/search_cells.py b/lib/tf_models/cell_searchs/search_cells.py 
new file mode 100644 index 0000000..e93c84a --- /dev/null +++ b/lib/tf_models/cell_searchs/search_cells.py @@ -0,0 +1,50 @@ +################################################## +# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # +################################################## +import math, random +import tensorflow as tf +from copy import deepcopy +from ..cell_operations import OPS + + +class SearchCell(tf.keras.layers.Layer): + + def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False): + super(SearchCell, self).__init__() + + self.op_names = deepcopy(op_names) + self.max_nodes = max_nodes + self.in_dim = C_in + self.out_dim = C_out + self.edge_keys = [] + for i in range(1, max_nodes): + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + if j == 0: + xlists = [OPS[op_name](C_in , C_out, stride, affine) for op_name in op_names] + else: + xlists = [OPS[op_name](C_in , C_out, 1, affine) for op_name in op_names] + for k, op in enumerate(xlists): + setattr(self, '{:}.{:}'.format(node_str, k), op) + self.edge_keys.append( node_str ) + self.edge_keys = sorted(self.edge_keys) + self.edge2index = {key:i for i, key in enumerate(self.edge_keys)} + self.num_edges = len(self.edge_keys) + + def call(self, inputs, weightss, training): + w_lst = tf.split(weightss, self.num_edges, 0) + nodes = [inputs] + for i in range(1, self.max_nodes): + inter_nodes = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + edge_idx = self.edge2index[node_str] + op_outps = [] + for k, op_name in enumerate(self.op_names): + op = getattr(self, '{:}.{:}'.format(node_str, k)) + op_outps.append( op(nodes[j], training) ) + stack_op_outs = tf.stack(op_outps, axis=-1) + weighted_sums = tf.math.multiply(stack_op_outs, w_lst[edge_idx]) + inter_nodes.append( tf.math.reduce_sum(weighted_sums, axis=-1) ) + nodes.append( tf.math.add_n(inter_nodes) ) + return nodes[-1] diff --git a/lib/tf_models/cell_searchs/search_model_gdas.py b/lib/tf_models/cell_searchs/search_model_gdas.py new file mode 100644 index 0000000..1df26a6 --- /dev/null +++ b/lib/tf_models/cell_searchs/search_model_gdas.py @@ -0,0 +1,99 @@ +########################################################################### +# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 # +########################################################################### +import tensorflow as tf +import numpy as np +from copy import deepcopy +from ..cell_operations import ResNetBasicblock +from .search_cells import SearchCell + + +def sample_gumbel(shape, eps=1e-20): + U = tf.random.uniform(shape, minval=0, maxval=1) + return -tf.math.log(-tf.math.log(U + eps) + eps) + + +def gumbel_softmax(logits, temperature): + gumbel_softmax_sample = logits + sample_gumbel(tf.shape(logits)) + y = tf.nn.softmax(gumbel_softmax_sample / temperature) + return y + + +class TinyNetworkGDAS(tf.keras.Model): + + def __init__(self, C, N, max_nodes, num_classes, search_space, affine): + super(TinyNetworkGDAS, self).__init__() + self._C = C + self._layerN = N + self.max_nodes = max_nodes + self.stem = tf.keras.Sequential([ + tf.keras.layers.Conv2D(16, 3, 1, padding='same', use_bias=False), + tf.keras.layers.BatchNormalization()], name='stem') + + layer_channels = [C ] * N + [C*2 ] + [C*2 ] * N + [C*4 ] + [C*4 ] * N + layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N + + C_prev, num_edge, edge2index = C, None, None + for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)): + cell_prefix = 
'cell-{:03d}'.format(index) + #with tf.name_scope(cell_prefix) as scope: + if reduction: + cell = ResNetBasicblock(C_prev, C_curr, 2) + else: + cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine) + if num_edge is None: num_edge, edge2index = cell.num_edges, cell.edge2index + else: assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges) + C_prev = cell.out_dim + setattr(self, cell_prefix, cell) + self.num_layers = len(layer_reductions) + self.op_names = deepcopy( search_space ) + self.edge2index = edge2index + self.num_edge = num_edge + self.lastact = tf.keras.Sequential([ + tf.keras.layers.BatchNormalization(), + tf.keras.layers.ReLU(), + tf.keras.layers.GlobalAvgPool2D(), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(num_classes, activation='softmax')], name='lastact') + #self.arch_parameters = nn.Parameter( 1e-3*torch.randn(num_edge, len(search_space)) ) + arch_init = tf.random_normal_initializer(mean=0, stddev=0.001) + self.arch_parameters = tf.Variable(initial_value=arch_init(shape=(num_edge, len(search_space)), dtype='float32'), trainable=True, name='arch-encoding') + + def get_alphas(self): + xlist = self.trainable_variables + return [x for x in xlist if 'arch-encoding' in x.name] + + def get_weights(self): + xlist = self.trainable_variables + return [x for x in xlist if 'arch-encoding' not in x.name] + + def get_np_alphas(self): + arch_nps = self.arch_parameters.numpy() + arch_ops = np.exp(arch_nps) / np.sum(np.exp(arch_nps), axis=-1, keepdims=True) + return arch_ops + + def genotype(self): + genotypes, arch_nps = [], self.arch_parameters.numpy() + for i in range(1, self.max_nodes): + xlist = [] + for j in range(i): + node_str = '{:}<-{:}'.format(i, j) + weights = arch_nps[ self.edge2index[node_str] ] + op_name = self.op_names[ weights.argmax().item() ] + xlist.append((op_name, j)) + genotypes.append( tuple(xlist) ) + return genotypes + + # + def call(self, inputs, tau, training): + weightss = tf.cond(tau < 0, lambda: tf.nn.softmax(self.arch_parameters, axis=1), + lambda: gumbel_softmax(tf.math.log_softmax(self.arch_parameters, axis=1), tau)) + feature = self.stem(inputs, training) + for idx in range(self.num_layers): + cell = getattr(self, 'cell-{:03d}'.format(idx)) + if isinstance(cell, SearchCell): + feature = cell.call(feature, weightss, training) + else: + feature = cell(feature, training) + logits = self.lastact(feature, training) + return logits diff --git a/lib/tf_optimizers/__init__.py b/lib/tf_optimizers/__init__.py new file mode 100644 index 0000000..c72fe17 --- /dev/null +++ b/lib/tf_optimizers/__init__.py @@ -0,0 +1 @@ +from .weight_decay_optimizers import AdamW, SGDW diff --git a/lib/tf_optimizers/weight_decay_optimizers.py b/lib/tf_optimizers/weight_decay_optimizers.py new file mode 100644 index 0000000..b4e72dc --- /dev/null +++ b/lib/tf_optimizers/weight_decay_optimizers.py @@ -0,0 +1,422 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class to make optimizers weight decay ready.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +class DecoupledWeightDecayExtension(object): + """This class allows extending optimizers with decoupled weight decay. + + It implements the decoupled weight decay described by Loshchilov & Hutter + (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is + decoupled from the optimization steps w.r.t. the loss function. + For SGD variants, this simplifies hyperparameter search since it decouples + the settings of weight decay and learning rate. + For adaptive gradient algorithms, it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield + better training loss and generalization error in the paper above. + + This class alone is not an optimizer but rather extends existing + optimizers with decoupled weight decay. We explicitly define the two + examples used in the above paper (SGDW and AdamW), but in general this + can extend any OptimizerX by using + `extend_with_decoupled_weight_decay( + OptimizerX, weight_decay=weight_decay)`. + In order for it to work, it must be the first base class of the optimizer + with weight decay, e.g. + + ```python + class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam): + def __init__(self, weight_decay, *args, **kwargs): + super(AdamW, self).__init__(weight_decay, *args, **kwargs) + ``` + + Note: this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of 'var' in the update step! + + Note: when applying a decay to the learning rate, be sure to manually apply + the decay to the `weight_decay` as well. For example: + + ```python + step = tf.Variable(0, trainable=False) + schedule = tf.optimizers.schedules.PiecewiseConstantDecay( + [10000, 15000], [1e-0, 1e-1, 1e-2]) + # lr and wd can be a function or a tensor + lr = 1e-1 * schedule(step) + wd = lambda: 1e-4 * schedule(step) + + # ... + + optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd) + ``` + """ + + def __init__(self, weight_decay, **kwargs): + """Extension class that adds weight decay to an optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value, the factor by + which a variable is decayed in the update step. + **kwargs: Optional keyword arguments that are passed on to the + constructor of the wrapped optimizer. + """ + wd = kwargs.pop('weight_decay', weight_decay) + super(DecoupledWeightDecayExtension, self).__init__(**kwargs) + self._decay_var_list = None # is set in minimize or apply_gradients + self._set_hyper('weight_decay', wd) + + def get_config(self): + config = super(DecoupledWeightDecayExtension, self).get_config() + config.update({ + 'weight_decay': + self._serialize_hyperparameter('weight_decay'), + }) + return config + + def minimize(self, + loss, + var_list, + grad_loss=None, + name=None, + decay_var_list=None): + """Minimize `loss` by updating `var_list`. + + This method simply computes gradient using `tf.GradientTape` and calls + `apply_gradients()`.
If you want to process the gradient before + applying then call `tf.GradientTape` and `apply_gradients()` explicitly + instead of using this function. + + Args: + loss: A callable taking no arguments which returns the value to + minimize. + var_list: list or tuple of `Variable` objects to update to + minimize `loss`, or a callable returning the list or tuple of + `Variable` objects. Use callable when the variable list would + otherwise be incomplete before `minimize` since the variables + are created at the first time `loss` is called. + grad_loss: Optional. A `Tensor` holding the gradient computed for + `loss`. + decay_var_list: Optional list of variables to be decayed. Defaults + to all variables in var_list. + name: Optional name for the returned operation. + Returns: + An Operation that updates the variables in `var_list`. If + `global_step` was not `None`, that operation also increments + `global_step`. + Raises: + ValueError: If some of the variables are not `Variable` objects. + """ + self._decay_var_list = set(decay_var_list) if decay_var_list else False + return super(DecoupledWeightDecayExtension, self).minimize( + loss, var_list=var_list, grad_loss=grad_loss, name=name) + + def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None): + """Apply gradients to variables. + + This is the second part of `minimize()`. It returns an `Operation` that + applies gradients. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + name: Optional name for the returned operation. Default to the + name passed to the `Optimizer` constructor. + decay_var_list: Optional list of variables to be decayed. Defaults + to all variables in var_list. + Returns: + An `Operation` that applies the specified gradients. If + `global_step` was not None, that operation also increments + `global_step`. + Raises: + TypeError: If `grads_and_vars` is malformed. + ValueError: If none of the variables have gradients. + """ + self._decay_var_list = set(decay_var_list) if decay_var_list else False + return super(DecoupledWeightDecayExtension, self).apply_gradients( + grads_and_vars, name=name) + + def _decay_weights_op(self, var): + if not self._decay_var_list or var in self._decay_var_list: + return var.assign_sub( + self._get_hyper('weight_decay', var.dtype) * var, + self._use_locking) + return tf.no_op() + + def _decay_weights_sparse_op(self, var, indices): + if not self._decay_var_list or var in self._decay_var_list: + update = (-self._get_hyper('weight_decay', var.dtype) * tf.gather( + var, indices)) + return self._resource_scatter_add(var, indices, update) + return tf.no_op() + + # Here, we overwrite the apply functions that the base optimizer calls. + # super().apply_x resolves to the apply_x function of the BaseOptimizer. + + def _resource_apply_dense(self, grad, var): + with tf.control_dependencies([self._decay_weights_op(var)]): + return super(DecoupledWeightDecayExtension, + self)._resource_apply_dense(grad, var) + + def _resource_apply_sparse(self, grad, var, indices): + decay_op = self._decay_weights_sparse_op(var, indices) + with tf.control_dependencies([decay_op]): + return super(DecoupledWeightDecayExtension, + self)._resource_apply_sparse(grad, var, indices) + + +def extend_with_decoupled_weight_decay(base_optimizer): + """Factory function returning an optimizer class with decoupled weight + decay. + + Returns an optimizer class. An instance of the returned class computes the + update step of `base_optimizer` and additionally decays the weights. 
+ E.g., the class returned by + `extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam)` is + equivalent to `tfa.optimizers.AdamW`. + + The API of the new optimizer class slightly differs from the API of the + base optimizer: + - The first argument to the constructor is the weight decay rate. + - `minimize` and `apply_gradients` accept the optional keyword argument + `decay_var_list`, which specifies the variables that should be decayed. + If `None`, all variables that are optimized are decayed. + + Usage example: + ```python + # MyAdamW is a new class + MyAdamW = extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam) + # Create a MyAdamW object + optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001) + # update var1, var2 but only decay var1 + optimizer.minimize(loss, var_list=[var1, var2], decay_var_list=[var1]) + ``` + + Note: this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of 'var' in the update step! + + Note: when applying a decay to the learning rate, be sure to manually apply + the decay to the `weight_decay` as well. For example: + + ```python + step = tf.Variable(0, trainable=False) + schedule = tf.optimizers.schedules.PiecewiseConstantDecay( + [10000, 15000], [1e-0, 1e-1, 1e-2]) + # lr and wd can be a function or a tensor + lr = 1e-1 * schedule(step) + wd = lambda: 1e-4 * schedule(step) + + # ... + + optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd) + ``` + + Note: you might want to register your own custom optimizer using + `tf.keras.utils.get_custom_objects()`. + + Args: + base_optimizer: An optimizer class that inherits from + tf.optimizers.Optimizer. + + Returns: + A new optimizer class that inherits from DecoupledWeightDecayExtension + and base_optimizer. + """ + + class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension, + base_optimizer): + """Base_optimizer with decoupled weight decay. + + This class computes the update step of `base_optimizer` and + additionally decays the variable with the weight decay being + decoupled from the optimization steps w.r.t. the loss + function, as described by Loshchilov & Hutter + (https://arxiv.org/pdf/1711.05101.pdf). For SGD variants, this + simplifies hyperparameter search since it decouples the settings + of weight decay and learning rate. For adaptive gradient + algorithms, it regularizes variables with large gradients more + than L2 regularization would, which was shown to yield better + training loss and generalization error in the paper above. + """ + + def __init__(self, weight_decay, *args, **kwargs): + # super delegation is necessary here + super(OptimizerWithDecoupledWeightDecay, self).__init__( + weight_decay, *args, **kwargs) + + return OptimizerWithDecoupledWeightDecay + + +class SGDW(DecoupledWeightDecayExtension, tf.keras.optimizers.SGD): + """Optimizer that implements the Momentum algorithm with weight_decay. + + This is an implementation of the SGDW optimizer described in "Decoupled + Weight Decay Regularization" by Loshchilov & Hutter + (https://arxiv.org/abs/1711.05101) + ([pdf](https://arxiv.org/pdf/1711.05101.pdf)). + It computes the update step of `tf.keras.optimizers.SGD` and additionally + decays the variable. Note that this is different from adding + L2 regularization on the variables to the loss. Decoupling the weight decay + from other hyperparameters (in particular the learning rate) simplifies + hyperparameter search.
+ + For further information see the documentation of the SGD Optimizer. + + This optimizer can also be instantiated as + ```python + extend_with_decoupled_weight_decay(tf.keras.optimizers.SGD, + weight_decay=weight_decay) + ``` + + Note: when applying a decay to the learning rate, be sure to manually apply + the decay to the `weight_decay` as well. For example: + + ```python + step = tf.Variable(0, trainable=False) + schedule = tf.optimizers.schedules.PiecewiseConstantDecay( + [10000, 15000], [1e-0, 1e-1, 1e-2]) + # lr and wd can be a function or a tensor + lr = 1e-1 * schedule(step) + wd = lambda: 1e-4 * schedule(step) + + # ... + + optimizer = tfa.optimizers.SGDW( + learning_rate=lr, weight_decay=wd, momentum=0.9) + ``` + """ + + def __init__(self, + weight_decay, + learning_rate=0.001, + momentum=0.0, + nesterov=False, + name='SGDW', + **kwargs): + """Construct a new SGDW optimizer. + + For further information see the documentation of the SGD Optimizer. + + Args: + learning_rate: float hyperparameter >= 0. Learning rate. + momentum: float hyperparameter >= 0 that accelerates SGD in the + relevant direction and dampens oscillations. + nesterov: boolean. Whether to apply Nesterov momentum. + name: Optional name prefix for the operations created when applying + gradients. Defaults to 'SGD'. + **kwargs: keyword arguments. Allowed to be {`clipnorm`, + `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by + norm; `clipvalue` is clip gradients by value, `decay` is + included for backward compatibility to allow time inverse decay + of learning rate. `lr` is included for backward compatibility, + recommended to use `learning_rate` instead. + """ + super(SGDW, self).__init__( + weight_decay, + learning_rate=learning_rate, + momentum=momentum, + nesterov=nesterov, + name=name, + **kwargs) + + +class AdamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Adam): + """Optimizer that implements the Adam algorithm with weight decay. + + This is an implementation of the AdamW optimizer described in "Decoupled + Weight Decay Regularization" by Loshchilov & Hutter + (https://arxiv.org/abs/1711.05101) + ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). + + It computes the update step of `tf.keras.optimizers.Adam` and additionally + decays the variable. Note that this is different from adding L2 + regularization on the variables to the loss: it regularizes variables with + large gradients more than L2 regularization would, which was shown to yield + better training loss and generalization error in the paper above. + + For further information see the documentation of the Adam Optimizer. + + This optimizer can also be instantiated as + ```python + extend_with_decoupled_weight_decay(tf.keras.optimizers.Adam, + weight_decay=weight_decay) + ``` + + Note: when applying a decay to the learning rate, be sure to manually apply + the decay to the `weight_decay` as well. For example: + + ```python + step = tf.Variable(0, trainable=False) + schedule = tf.optimizers.schedules.PiecewiseConstantDecay( + [10000, 15000], [1e-0, 1e-1, 1e-2]) + # lr and wd can be a function or a tensor + lr = 1e-1 * schedule(step) + wd = lambda: 1e-4 * schedule(step) + + # ... + + optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd) + ``` + """ + + def __init__(self, + weight_decay, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-07, + amsgrad=False, + name="AdamW", + **kwargs): + """Construct a new AdamW optimizer. + + For further information see the documentation of the Adam Optimizer. 
+ + Args: + weight_decay: A Tensor or a floating point value. The weight decay. + learning_rate: A Tensor or a floating point value. The learning + rate. + beta_1: A float value or a constant float tensor. The exponential + decay rate for the 1st moment estimates. + beta_2: A float value or a constant float tensor. The exponential + decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just + before Section 2.1), not the epsilon in Algorithm 1 of the + paper. + amsgrad: boolean. Whether to apply AMSGrad variant of this + algorithm from the paper "On the Convergence of Adam and + beyond". + name: Optional name for the operations created when applying + gradients. Defaults to "AdamW". + **kwargs: keyword arguments. Allowed to be {`clipnorm`, + `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by + norm; `clipvalue` is clip gradients by value, `decay` is + included for backward compatibility to allow time inverse decay + of learning rate. `lr` is included for backward compatibility, + recommended to use `learning_rate` instead. + """ + super(AdamW, self).__init__( + weight_decay, + learning_rate=learning_rate, + beta_1=beta_1, + beta_2=beta_2, + epsilon=epsilon, + amsgrad=amsgrad, + name=name, + **kwargs)
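For reference, the snippet below is a minimal usage sketch of the `SGDW`/`AdamW` classes introduced in `lib/tf_optimizers`, combining the hyper-parameters used by `exps-tf/GDAS.py` with the joint learning-rate/weight-decay schedule recommended in the docstrings above. It assumes TensorFlow 2.x and that `lib` is on `sys.path` (as `exps-tf/GDAS.py` arranges); the stand-in model and the schedule boundaries are illustrative only, not part of the patch.

```python
# Minimal sketch (not part of the patch): driving the decoupled-weight-decay
# optimizers added above. Assumes TF 2.x and `lib` on sys.path; the stand-in
# model and schedule boundaries are illustrative.
import tensorflow as tf
from tf_optimizers import SGDW, AdamW

step = tf.Variable(0, trainable=False)
schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    [10000, 15000], [1e-0, 1e-1, 1e-2])
# Decay the weight-decay factor together with the learning rate,
# as the weight_decay_optimizers docstrings recommend.
lr = lambda: 0.025 * schedule(step)    # w_lr default in exps-tf/GDAS.py
wd = lambda: 0.0005 * schedule(step)   # w_weight_decay default

w_optimizer = SGDW(weight_decay=wd, learning_rate=lr, momentum=0.9, nesterov=True)
# Architecture-encoding optimizer, mirroring the exps-tf/GDAS.py defaults.
a_optimizer = AdamW(weight_decay=0.001, learning_rate=0.0003,
                    beta_1=0.5, beta_2=0.999, epsilon=1e-07)

# One illustrative weight update on a stand-in model.
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
images = tf.random.normal([64, 32])
labels = tf.random.uniform([64], maxval=10, dtype=tf.int32)
with tf.GradientTape() as tape:
    loss = loss_fn(labels, model(images))
grads = tape.gradient(loss, model.trainable_variables)
# decay_var_list can restrict the decoupled decay to a subset of variables,
# e.g. to exclude BatchNorm parameters or the GDAS architecture encoding.
w_optimizer.apply_gradients(zip(grads, model.trainable_variables),
                            decay_var_list=model.trainable_variables)
step.assign_add(1)
```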