Prototype generic nas model (cont.).
This commit is contained in:
parent
7ca2ca70b4
commit
c34620ab1b
@ -2,7 +2,7 @@
|
|||||||
"scheduler": ["str", "cos"],
|
"scheduler": ["str", "cos"],
|
||||||
"LR" : ["float", "0.025"],
|
"LR" : ["float", "0.025"],
|
||||||
"eta_min" : ["float", "0.001"],
|
"eta_min" : ["float", "0.001"],
|
||||||
"epochs" : ["int", "250"],
|
"epochs" : ["int", "150"],
|
||||||
"warmup" : ["int", "0"],
|
"warmup" : ["int", "0"],
|
||||||
"optim" : ["str", "SGD"],
|
"optim" : ["str", "SGD"],
|
||||||
"decay" : ["float", "0.0005"],
|
"decay" : ["float", "0.0005"],
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020 #
|
||||||
######################################################################################
|
######################################################################################
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --rand_seed 1
|
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --rand_seed 1
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1
|
# python ./exps/algos-v2/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo darts-v1 --drop_path_rate 0.3
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1
|
# python ./exps/algos-v2/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo darts-v1
|
||||||
####
|
####
|
||||||
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v2 --rand_seed 1
|
# python ./exps/algos-v2/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo darts-v2 --rand_seed 1
|
||||||
@ -30,6 +30,73 @@ from models import get_cell_based_tiny_net, get_search_spaces
|
|||||||
from nas_201_api import NASBench201API as API
|
from nas_201_api import NASBench201API as API
|
||||||
|
|
||||||
|
|
||||||
|
# The following three functions are used for DARTS-V2
|
||||||
|
def _concat(xs):
|
||||||
|
return torch.cat([x.view(-1) for x in xs])
|
||||||
|
|
||||||
|
|
||||||
|
def _hessian_vector_product(vector, network, criterion, base_inputs, base_targets, r=1e-2):
|
||||||
|
R = r / _concat(vector).norm()
|
||||||
|
for p, v in zip(network.weights, vector):
|
||||||
|
p.data.add_(R, v)
|
||||||
|
_, logits = network(base_inputs)
|
||||||
|
loss = criterion(logits, base_targets)
|
||||||
|
grads_p = torch.autograd.grad(loss, network.alphas)
|
||||||
|
|
||||||
|
for p, v in zip(network.weights, vector):
|
||||||
|
p.data.sub_(2*R, v)
|
||||||
|
_, logits = network(base_inputs)
|
||||||
|
loss = criterion(logits, base_targets)
|
||||||
|
grads_n = torch.autograd.grad(loss, network.alphas)
|
||||||
|
|
||||||
|
for p, v in zip(network.weights, vector):
|
||||||
|
p.data.add_(R, v)
|
||||||
|
return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)]
|
||||||
|
|
||||||
|
|
||||||
|
def backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets):
|
||||||
|
# _compute_unrolled_model
|
||||||
|
_, logits = network(base_inputs)
|
||||||
|
loss = criterion(logits, base_targets)
|
||||||
|
LR, WD, momentum = w_optimizer.param_groups[0]['lr'], w_optimizer.param_groups[0]['weight_decay'], w_optimizer.param_groups[0]['momentum']
|
||||||
|
with torch.no_grad():
|
||||||
|
theta = _concat(network.weights)
|
||||||
|
try:
|
||||||
|
moment = _concat(w_optimizer.state[v]['momentum_buffer'] for v in network.weights)
|
||||||
|
moment = moment.mul_(momentum)
|
||||||
|
except:
|
||||||
|
moment = torch.zeros_like(theta)
|
||||||
|
dtheta = _concat(torch.autograd.grad(loss, network.weights)) + WD*theta
|
||||||
|
params = theta.sub(LR, moment+dtheta)
|
||||||
|
unrolled_model = deepcopy(network)
|
||||||
|
model_dict = unrolled_model.state_dict()
|
||||||
|
new_params, offset = {}, 0
|
||||||
|
for k, v in network.named_parameters():
|
||||||
|
if 'arch_parameters' in k: continue
|
||||||
|
v_length = np.prod(v.size())
|
||||||
|
new_params[k] = params[offset: offset+v_length].view(v.size())
|
||||||
|
offset += v_length
|
||||||
|
model_dict.update(new_params)
|
||||||
|
unrolled_model.load_state_dict(model_dict)
|
||||||
|
|
||||||
|
unrolled_model.zero_grad()
|
||||||
|
_, unrolled_logits = unrolled_model(arch_inputs)
|
||||||
|
unrolled_loss = criterion(unrolled_logits, arch_targets)
|
||||||
|
unrolled_loss.backward()
|
||||||
|
|
||||||
|
dalpha = unrolled_model.arch_parameters.grad
|
||||||
|
vector = [v.grad.data for v in unrolled_model.weights]
|
||||||
|
[implicit_grads] = _hessian_vector_product(vector, network, criterion, base_inputs, base_targets)
|
||||||
|
|
||||||
|
dalpha.data.sub_(LR, implicit_grads.data)
|
||||||
|
|
||||||
|
if network.arch_parameters.grad is None:
|
||||||
|
network.arch_parameters.grad = deepcopy( dalpha )
|
||||||
|
else:
|
||||||
|
network.arch_parameters.grad.data.copy_( dalpha.data )
|
||||||
|
return unrolled_loss.detach(), unrolled_logits.detach()
|
||||||
|
|
||||||
|
|
||||||
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, algo, logger):
|
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, algo, logger):
|
||||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||||
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
@ -81,9 +148,12 @@ def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer
|
|||||||
else:
|
else:
|
||||||
raise ValueError('Invalid algo name : {:}'.format(algo))
|
raise ValueError('Invalid algo name : {:}'.format(algo))
|
||||||
network.zero_grad()
|
network.zero_grad()
|
||||||
_, logits = network(arch_inputs)
|
if algo == 'darts-v2':
|
||||||
arch_loss = criterion(logits, arch_targets)
|
arch_loss, logits = backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets)
|
||||||
arch_loss.backward()
|
else:
|
||||||
|
_, logits = network(arch_inputs)
|
||||||
|
arch_loss = criterion(logits, arch_targets)
|
||||||
|
arch_loss.backward()
|
||||||
a_optimizer.step()
|
a_optimizer.step()
|
||||||
# record
|
# record
|
||||||
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
|
arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
|
||||||
@ -192,7 +262,10 @@ def main(xargs):
|
|||||||
params = count_parameters_in_MB(search_model)
|
params = count_parameters_in_MB(search_model)
|
||||||
logger.log('The parameters of the search model = {:.2f} MB'.format(params))
|
logger.log('The parameters of the search model = {:.2f} MB'.format(params))
|
||||||
logger.log('search-space : {:}'.format(search_space))
|
logger.log('search-space : {:}'.format(search_space))
|
||||||
api = API()
|
try:
|
||||||
|
api = API()
|
||||||
|
except:
|
||||||
|
api = None
|
||||||
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
||||||
|
|
||||||
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
last_info, model_base_path, model_best_path = logger.path('info'), logger.path('model'), logger.path('best')
|
||||||
@ -224,6 +297,7 @@ def main(xargs):
|
|||||||
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
|
epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
|
||||||
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
|
logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(epoch_str, need_time, min(w_scheduler.get_lr())))
|
||||||
|
|
||||||
|
network.set_drop_path(float(epoch+1) / total_epoch, xargs.drop_path_rate)
|
||||||
search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
|
search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
|
||||||
= search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, xargs.algo, logger)
|
= search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, xargs.algo, logger)
|
||||||
search_time.update(time.time() - start_time)
|
search_time.update(time.time() - start_time)
|
||||||
@ -314,6 +388,7 @@ if __name__ == '__main__':
|
|||||||
# architecture leraning rate
|
# architecture leraning rate
|
||||||
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
|
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
|
||||||
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
|
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
|
||||||
|
parser.add_argument('--drop_path_rate' , type=float, help='The drop path rate.')
|
||||||
# log
|
# log
|
||||||
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
|
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
|
||||||
parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
|
parser.add_argument('--save_dir', type=str, default='./output/search', help='Folder to save checkpoints and log.')
|
||||||
|
@ -67,6 +67,14 @@ class GenericNAS201Model(nn.Module):
|
|||||||
if mode == 'dynamic': self.dynamic_cell = deepcopy(dynamic_cell)
|
if mode == 'dynamic': self.dynamic_cell = deepcopy(dynamic_cell)
|
||||||
else : self.dynamic_cell = None
|
else : self.dynamic_cell = None
|
||||||
|
|
||||||
|
def set_drop_path(self, progress, drop_path_rate):
|
||||||
|
if drop_path_rate is None:
|
||||||
|
self._drop_path = None
|
||||||
|
elif progress is None:
|
||||||
|
self._drop_path = drop_path_rate
|
||||||
|
else:
|
||||||
|
self._drop_path = progress * drop_path_rate
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mode(self):
|
def mode(self):
|
||||||
return self._mode
|
return self._mode
|
||||||
@ -210,6 +218,8 @@ class GenericNAS201Model(nn.Module):
|
|||||||
feature = cell.forward_gdas(feature, alphas, index)
|
feature = cell.forward_gdas(feature, alphas, index)
|
||||||
else: raise ValueError('invalid mode={:}'.format(self.mode))
|
else: raise ValueError('invalid mode={:}'.format(self.mode))
|
||||||
else: feature = cell(feature)
|
else: feature = cell(feature)
|
||||||
|
if self.drop_path is not None:
|
||||||
|
feature = drop_path(feature, self.drop_path)
|
||||||
out = self.lastact(feature)
|
out = self.lastact(feature)
|
||||||
out = self.global_pooling(out)
|
out = self.global_pooling(out)
|
||||||
out = out.view(out.size(0), -1)
|
out = out.view(out.size(0), -1)
|
||||||
|
Loading…
Reference in New Issue
Block a user