Prototype generic NAS model (cont.) for ENAS.

D-X-Y 2020-07-19 11:25:37 +00:00
parent b9a5d2880f
commit 16c5651bdc
2 changed files with 172 additions and 12 deletions

exps/algos-v2/search-cell.py

@@ -20,6 +20,10 @@
 # python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo random --rand_seed 777
 # python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo random
 # python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo random
+####
+# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
+# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas
+# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas
 ######################################################################################
 import os, sys, time, random, argparse
 import numpy as np
@@ -130,6 +134,11 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
       network.set_cal_mode('joint', None)
     elif algo == 'random':
       network.set_cal_mode('urs', None)
+    elif algo == 'enas':
+      with torch.no_grad():
+        network.controller.eval()
+        _, _, sampled_arch = network.controller()
+      network.set_cal_mode('dynamic', sampled_arch)
     else:
       raise ValueError('Invalid algo name : {:}'.format(algo))
 
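For orientation: under ENAS, every mini-batch of the weight-training phase first samples one architecture from the controller (kept in eval mode under torch.no_grad(), since only the shared weights are trained here) and then runs the network restricted to that sub-graph. A minimal sketch of one such step, assuming the network/set_cal_mode interface from the diff; enas_weight_step is a hypothetical helper, not part of the commit:

import torch

def enas_weight_step(network, criterion, w_optimizer, inputs, targets):
  # hypothetical helper: one shared-weight update on a controller-sampled architecture
  with torch.no_grad():                  # the controller is NOT optimized in this phase
    network.controller.eval()
    _, _, sampled_arch = network.controller()
  network.set_cal_mode('dynamic', sampled_arch)  # activate only the sampled sub-graph
  w_optimizer.zero_grad()
  _, logits = network(inputs)
  loss = criterion(logits, targets)
  loss.backward()                        # gradients flow to the shared weights only
  w_optimizer.step()
  return loss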
@@ -153,16 +162,21 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
       network.set_cal_mode('joint', None)
     elif algo == 'random':
       network.set_cal_mode('urs', None)
-    else:
+    elif algo != 'enas':
       raise ValueError('Invalid algo name : {:}'.format(algo))
     network.zero_grad()
     if algo == 'darts-v2':
       arch_loss, logits = backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets)
+      a_optimizer.step()
+    elif algo == 'random' or algo == 'enas':
+      with torch.no_grad():
+        _, logits = network(arch_inputs)
+        arch_loss = criterion(logits, arch_targets)
     else:
       _, logits = network(arch_inputs)
       arch_loss = criterion(logits, arch_targets)
       arch_loss.backward()
       a_optimizer.step()
     # record
     arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
     arch_losses.update(arch_loss.item(), arch_inputs.size(0))
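Read together, this hunk turns the architecture-update branch into a three-way dispatch: for 'random' and 'enas' the validation forward pass runs under torch.no_grad() and the loss is only logged, because there are no architecture parameters to update here (the ENAS controller is trained separately by REINFORCE, below). A simplified view of the resulting control flow; a sketch, not the verbatim file:

if algo == 'darts-v2':
  arch_loss, logits = backward_step_unrolled(network, criterion, base_inputs, base_targets,
                                             w_optimizer, arch_inputs, arch_targets)
  a_optimizer.step()                     # second-order DARTS steps the arch encoding here
elif algo in ('random', 'enas'):         # equivalent to the diff's two-way comparison
  with torch.no_grad():                  # evaluation only; nothing to optimize in this step
    _, logits = network(arch_inputs)
    arch_loss = criterion(logits, arch_targets)
else:                                    # darts-v1 / gdas / setn
  _, logits = network(arch_inputs)
  arch_loss = criterion(logits, arch_targets)
  arch_loss.backward()
  a_optimizer.step()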
@@ -182,6 +196,76 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
   return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
 
 
+def train_controller(xloader, network, criterion, optimizer, prev_baseline, epoch_str, print_freq, logger):
+  # prev_baseline: the baseline reward (a moving average of the validation accuracy) from the previous epoch; None at the first epoch
+  data_time, batch_time = AverageMeter(), AverageMeter()
+  GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), time.time()
+
+  controller_num_aggregate = 20
+  controller_train_steps = 50
+  controller_bl_dec = 0.99
+  controller_entropy_weight = 0.0001
+
+  network.eval()
+  network.controller.train()
+  network.controller.zero_grad()
+  loader_iter = iter(xloader)
+  for step in range(controller_train_steps * controller_num_aggregate):
+    try:
+      inputs, targets = next(loader_iter)
+    except StopIteration:
+      loader_iter = iter(xloader)
+      inputs, targets = next(loader_iter)
+    inputs = inputs.cuda(non_blocking=True)
+    targets = targets.cuda(non_blocking=True)
+    # measure data loading time
+    data_time.update(time.time() - xend)
+
+    log_prob, entropy, sampled_arch = network.controller()
+    with torch.no_grad():
+      network.set_cal_mode('dynamic', sampled_arch)
+      _, logits = network(inputs)
+      val_top1, val_top5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
+      val_top1 = val_top1.view(-1) / 100
+    reward = val_top1 + controller_entropy_weight * entropy
+    if prev_baseline is None:
+      baseline = val_top1
+    else:
+      baseline = prev_baseline - (1 - controller_bl_dec) * (prev_baseline - reward)
+    loss = -1 * log_prob * (reward - baseline)
+    # record
+    RewardMeter.update(reward.item())
+    BaselineMeter.update(baseline.item())
+    ValAccMeter.update(val_top1.item() * 100)
+    LossMeter.update(loss.item())
+    EntropyMeter.update(entropy.item())
+    # average the gradient over controller_num_aggregate samples
+    loss = loss / controller_num_aggregate
+    loss.backward(retain_graph=True)
+
+    # measure elapsed time
+    batch_time.update(time.time() - xend)
+    xend = time.time()
+    if (step + 1) % controller_num_aggregate == 0:
+      grad_norm = torch.nn.utils.clip_grad_norm_(network.controller.parameters(), 5.0)
+      GradnormMeter.update(grad_norm)
+      optimizer.step()
+      network.controller.zero_grad()
+
+    if step % print_freq == 0:
+      Sstr = '*Train-Controller* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, controller_train_steps * controller_num_aggregate)
+      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
+      Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})'.format(loss=LossMeter, top1=ValAccMeter, reward=RewardMeter, basel=BaselineMeter)
+      Estr = 'Entropy={:.4f} ({:.4f})'.format(EntropyMeter.val, EntropyMeter.avg)
+      logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Estr)
+
+  return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg
+
+
 def get_best_arch(xloader, network, n_samples, algo):
   with torch.no_grad():
     network.eval()
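The baseline update in train_controller is an exponential moving average in disguise: prev - (1 - dec) * (prev - reward) equals dec * prev + (1 - dec) * reward, so with controller_bl_dec = 0.99 the baseline moves 1% of the way toward each new reward, and REINFORCE scales the log-probability by the advantage (reward - baseline). A small numeric illustration (plain floats; the real code carries these as tensors):

dec = 0.99
baseline, rewards = 0.50, [0.60, 0.62, 0.58]
for r in rewards:
  baseline = baseline - (1 - dec) * (baseline - r)   # == dec*baseline + (1-dec)*r
  advantage = r - baseline                           # drives loss = -log_prob * advantage
  print('baseline={:.4f} advantage={:+.4f}'.format(baseline, advantage))
# baseline=0.5010 advantage=+0.0990
# baseline=0.5022 advantage=+0.1178
# baseline=0.5030 advantage=+0.0770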
@@ -192,6 +276,11 @@ def get_best_arch(xloader, network, n_samples, algo):
   elif algo.startswith('darts') or algo == 'gdas':
     arch = network.genotype
     archs, valid_accs = [arch], []
+  elif algo == 'enas':
+    archs, valid_accs = [], []
+    for _ in range(n_samples):
+      _, _, sampled_arch = network.controller()
+      archs.append(sampled_arch)
   else:
     raise ValueError('Invalid algorithm name : {:}'.format(algo))
   loader_iter = iter(xloader)
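For 'enas', get_best_arch therefore ranks n_samples controller-sampled architectures instead of reading them off arch parameters; the remainder of the function (unchanged, so not shown in this diff) scores each candidate on validation batches and keeps the best. A hedged sketch of that selection loop, where valid_accuracy is a hypothetical one-pass evaluator standing in for the unshown scoring code:

best_arch, best_acc = None, -1.0
for arch in archs:                        # archs were sampled from network.controller()
  network.set_cal_mode('dynamic', arch)   # evaluate exactly this sub-graph
  acc = valid_accuracy(network, loader_iter)  # hypothetical helper, not in the commit
  valid_accs.append(acc)
  if acc > best_acc:
    best_arch, best_acc = arch, acc
# get_best_arch then returns (best_arch, best_acc)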
@@ -245,7 +334,7 @@ def main(xargs):
   train_data, valid_data, xshape, class_num = get_datasets(xargs.dataset, xargs.data_path, -1)
   config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, logger)
-  search_loader, _, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \
+  search_loader, train_loader, valid_loader = get_nas_search_loaders(train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/', \
     (config.batch_size, config.test_batch_size), xargs.workers)
   logger.log('||||||| {:10s} ||||||| Search-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(xargs.dataset, len(search_loader), len(valid_loader), config.batch_size))
   logger.log('||||||| {:10s} ||||||| Config={:}'.format(xargs.dataset, config))
@@ -263,7 +352,7 @@ def main(xargs):
   logger.log('{:}'.format(search_model))
   w_optimizer, w_scheduler, criterion = get_optim_scheduler(search_model.weights, config)
-  a_optimizer = torch.optim.Adam(search_model.alphas, lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay)
+  a_optimizer = torch.optim.Adam(search_model.alphas, lr=xargs.arch_learning_rate, betas=(0.5, 0.999), weight_decay=xargs.arch_weight_decay, eps=xargs.arch_eps)
   logger.log('w-optimizer : {:}'.format(w_optimizer))
   logger.log('a-optimizer : {:}'.format(a_optimizer))
   logger.log('w-scheduler : {:}'.format(w_scheduler))
@@ -288,6 +377,8 @@ def main(xargs):
     start_epoch = last_info['epoch']
     checkpoint = torch.load(last_info['last_checkpoint'])
     genotypes = checkpoint['genotypes']
+    # keep baseline defined for every algo, since the checkpoint below always stores it
+    baseline = checkpoint['baseline'] if xargs.algo == 'enas' else None
     valid_accuracies = checkpoint['valid_accuracies']
     search_model.load_state_dict( checkpoint['search_model'] )
     w_scheduler.load_state_dict ( checkpoint['w_scheduler'] )
@@ -297,6 +388,7 @@ def main(xargs):
   else:
     logger.log("=> do not find the last-info file : {:}".format(last_info))
     start_epoch, valid_accuracies, genotypes = 0, {'best': -1}, {-1: network.return_topK(1, True)[0]}
+    baseline = None
 
   # start training
   start_time, search_time, epoch_time, total_epoch = time.time(), AverageMeter(), AverageMeter(), config.epochs + config.warmup
@@ -312,9 +404,13 @@ def main(xargs):
     search_time.update(time.time() - start_time)
     logger.log('[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'.format(epoch_str, search_w_loss, search_w_top1, search_w_top5, search_time.sum))
     logger.log('[{:}] search [arch] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'.format(epoch_str, search_a_loss, search_a_top1, search_a_top5))
+    if xargs.algo == 'enas':
+      ctl_loss, ctl_acc, baseline, ctl_reward \
+          = train_controller(valid_loader, network, criterion, a_optimizer, baseline, epoch_str, xargs.print_freq, logger)
+      logger.log('[{:}] controller : loss={:}, acc={:}, baseline={:}, reward={:}'.format(epoch_str, ctl_loss, ctl_acc, baseline, ctl_reward))
 
     genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.eval_candidate_num, xargs.algo)
-    if xargs.algo == 'setn':
+    if xargs.algo == 'setn' or xargs.algo == 'enas':
       network.set_cal_mode('dynamic', genotype)
     elif xargs.algo == 'gdas':
       network.set_cal_mode('gdas', None)
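With this hunk the ENAS main loop alternates three phases per epoch: search_func trains the shared weights on controller samples, train_controller updates the controller by REINFORCE on validation reward (reusing a_optimizer and threading the baseline across epochs and checkpoints), and get_best_arch derives the epoch's genotype. Schematically, as a sketch of the flow rather than the literal file (argument tails abbreviated):

for epoch in range(start_epoch, total_epoch):
  # 1) shared weights: one epoch over controller-sampled architectures
  search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, ...)
  # 2) controller: REINFORCE with an EMA baseline on validation accuracy
  ctl_loss, ctl_acc, baseline, ctl_reward = train_controller(
      valid_loader, network, criterion, a_optimizer, baseline, epoch_str, xargs.print_freq, logger)
  # 3) sample candidates, keep the best as this epoch's genotype, and adopt it
  genotype, temp_accuracy = get_best_arch(valid_loader, network, xargs.eval_candidate_num, 'enas')
  network.set_cal_mode('dynamic', genotype)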
@@ -333,6 +429,7 @@ def main(xargs):
     # save checkpoint
     save_path = save_checkpoint({'epoch'       : epoch + 1,
                 'args'         : deepcopy(xargs),
+                'baseline'     : baseline,
                 'search_model' : search_model.state_dict(),
                 'w_optimizer'  : w_optimizer.state_dict(),
                 'a_optimizer'  : a_optimizer.state_dict(),
@@ -377,7 +474,6 @@ def main(xargs):
   logger.close()
 
-
 
 if __name__ == '__main__':
   parser = argparse.ArgumentParser("Weight sharing NAS methods to search for cells.")
   parser.add_argument('--data_path' , type=str, help='Path to dataset')
@@ -396,7 +492,8 @@ if __name__ == '__main__':
   parser.add_argument('--config_path' , type=str, default='./configs/nas-benchmark/algos/weight-sharing.config', help='The path of configuration.')
   # architecture learning rate
   parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
-  parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
+  parser.add_argument('--arch_weight_decay' , type=float, default=1e-3, help='weight decay for arch encoding')
+  parser.add_argument('--arch_eps'          , type=float, default=1e-8, help='epsilon of the Adam optimizer for arch encoding')
   parser.add_argument('--drop_path_rate' , type=float, help='The drop path rate.')
   # log
   parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')

lib/models/cell_searchs/generic_model.py (filename inferred from the imports)

@@ -5,11 +5,75 @@ import torch, random
 import torch.nn as nn
 from copy import deepcopy
 from typing import Text
+from torch.distributions.categorical import Categorical
 
 from ..cell_operations import ResNetBasicblock, drop_path
 from .search_cells import NAS201SearchCell as SearchCell
 from .genotypes import Structure
-from .search_model_enas_utils import Controller
+
+
+class Controller(nn.Module):
+  # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
+  def __init__(self, edge2index, op_names, max_nodes, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0):
+    super(Controller, self).__init__()
+    # assign the attributes
+    self.max_nodes = max_nodes
+    self.num_edge = len(edge2index)
+    self.edge2index = edge2index
+    self.num_ops = len(op_names)
+    self.op_names = op_names
+    self.lstm_size = lstm_size
+    self.lstm_N = lstm_num_layers
+    self.tanh_constant = tanh_constant
+    self.temperature = temperature
+    # create parameters
+    self.register_parameter('input_vars', nn.Parameter(torch.Tensor(1, 1, lstm_size)))
+    self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N)
+    self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)
+    self.w_pred = nn.Linear(self.lstm_size, self.num_ops)
+    nn.init.uniform_(self.input_vars         , -0.1, 0.1)
+    nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
+    nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
+    nn.init.uniform_(self.w_embd.weight      , -0.1, 0.1)
+    nn.init.uniform_(self.w_pred.weight      , -0.1, 0.1)
+
+  def convert_structure(self, _arch):
+    genotypes = []
+    for i in range(1, self.max_nodes):
+      xlist = []
+      for j in range(i):
+        node_str = '{:}<-{:}'.format(i, j)
+        op_index = _arch[self.edge2index[node_str]]
+        op_name = self.op_names[op_index]
+        xlist.append((op_name, j))
+      genotypes.append( tuple(xlist) )
+    return Structure(genotypes)
+
+  def forward(self):
+    inputs, h0 = self.input_vars, None
+    log_probs, entropys, sampled_arch = [], [], []
+    for iedge in range(self.num_edge):
+      outputs, h0 = self.w_lstm(inputs, h0)
+      logits = self.w_pred(outputs)
+      logits = logits / self.temperature
+      logits = self.tanh_constant * torch.tanh(logits)
+      # distribution
+      op_distribution = Categorical(logits=logits)
+      op_index = op_distribution.sample()
+      sampled_arch.append( op_index.item() )
+      op_log_prob = op_distribution.log_prob(op_index)
+      log_probs.append( op_log_prob.view(-1) )
+      op_entropy = op_distribution.entropy()
+      entropys.append( op_entropy.view(-1) )
+      # obtain the input embedding for the next step
+      inputs = self.w_embd(op_index)
+    return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), self.convert_structure(sampled_arch)
 
 
 class GenericNAS201Model(nn.Module):
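The controller is a small op-level LSTM: it unrolls once per edge of the NAS-Bench-201 cell, turns each hidden state into logits over the op set (softened by temperature, bounded by tanh_constant * torch.tanh), samples an op index, and feeds that op's embedding back as the next input; the sampled indices are then assembled into a Structure. A minimal usage sketch with the standard NAS-Bench-201 op set, where edge2index is built to match the '{i}<-{j}' keys used by convert_structure (max_nodes=4 gives 6 edges):

op_names = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
max_nodes = 4
edge2index = {}            # keys: '1<-0', '2<-0', '2<-1', '3<-0', '3<-1', '3<-2'
for i in range(1, max_nodes):
  for j in range(i):
    edge2index['{:}<-{:}'.format(i, j)] = len(edge2index)

controller = Controller(edge2index, op_names, max_nodes)
log_prob, entropy, arch = controller()   # one REINFORCE sample
print(arch)       # a Structure over the 6 edges, e.g. |nor_conv_3x3~0|+|skip_connect~0|...
print(log_prob)   # scalar tensor: the sum of per-edge log-probabilities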
@@ -55,7 +119,7 @@ class GenericNAS201Model(nn.Module):
     assert self._algo is None, 'This function can only be called once.'
     self._algo = algo
     if algo == 'enas':
-      self.controller = Controller(len(self.edge2index), len(self._op_names))
+      self.controller = Controller(self.edge2index, self._op_names, self._max_nodes)
     else:
       self.arch_parameters = nn.Parameter( 1e-3*torch.randn(self._num_edge, len(self._op_names)) )
       if algo == 'gdas':
@@ -116,10 +180,9 @@ class GenericNAS201Model(nn.Module):
   def show_alphas(self):
     with torch.no_grad():
       if self._algo == 'enas':
-        import pdb; pdb.set_trace()
-        print('-')
+        return 'w_pred :\n{:}'.format(self.controller.w_pred.weight)
       else:
-        return 'arch-parameters :\n{:}'.format( nn.functional.softmax(self.arch_parameters, dim=-1).cpu() )
+        return 'arch-parameters :\n{:}'.format(nn.functional.softmax(self.arch_parameters, dim=-1).cpu())
 
   def extra_repr(self):