import torch


class Architect(object):
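    """Architecture-parameter (alpha) updates for DARTS-style search.

    Keeps a dedicated Adam optimizer over model.arch_parameters() and, per
    step, either performs a first-order alpha update on validation data
    ('fo' mode) or leaves the alphas untouched ('blank' mode).
    """
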
    def __init__(self, model, args):
        self.network_momentum = args.momentum
        self.network_weight_decay = args.weight_decay
        self.model = model
        self.optimizer = torch.optim.Adam(self.model.arch_parameters(),
                                          lr=args.arch_learning_rate, betas=(0.5, 0.999),
                                          weight_decay=args.arch_weight_decay)

        # snapshot the initial alphas so they can be restored later
        self._init_arch_parameters = []
        for alpha in self.model.arch_parameters():
            alpha_init = torch.zeros_like(alpha)
            alpha_init.data.copy_(alpha)
            self._init_arch_parameters.append(alpha_init)

        #### mode
        if args.method in ['darts', 'darts-proj', 'sdarts', 'sdarts-proj']:
            self.method = 'fo'  # first order update
        elif 'so' in args.method:
            # second-order updates live in architect.py; exit so self.method
            # is never left unset
            print('ERROR: PLEASE USE architect.py for second order darts'); exit(0)
        elif args.method in ['blank', 'blank-proj']:
            self.method = 'blank'
        else:
            print('ERROR: WRONG ARCH UPDATE METHOD', args.method); exit(0)

    def reset_arch_parameters(self):
        # restore every alpha to the value it had at initialization
        for alpha, alpha_init in zip(self.model.arch_parameters(), self._init_arch_parameters):
            alpha.data.copy_(alpha_init.data)

    def step(self, input_train, target_train, input_valid, target_valid, *args, **kwargs):
        if self.method == 'fo':
            shared = self._step_fo(input_train, target_train, input_valid, target_valid)
        elif self.method == 'so':
            raise NotImplementedError
        elif self.method == 'blank':  ## do not update alpha
            shared = None

        return shared

    #### first order
    def _step_fo(self, input_train, target_train, input_valid, target_valid):
        # one Adam step on the validation loss w.r.t. the alphas;
        # zero_grad() guards against stale gradients from a previous step
        self.optimizer.zero_grad()
        loss = self.model._loss(input_valid, target_valid)
        loss.backward()
        self.optimizer.step()
        return None

    #### darts 2nd order
    def _step_darts_so(self, input_train, target_train, input_valid, target_valid, eta, model_optimizer):
        raise NotImplementedError
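

# Minimal smoke test (illustrative sketch, not part of the original file).
# It exercises only the interface assumed above: a model exposing
# arch_parameters() and _loss(input, target), plus an args namespace with the
# fields read in __init__. The toy model and all hyperparameter values below
# are hypothetical placeholders.
if __name__ == '__main__':
    import argparse

    import torch.nn as nn

    class _ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            # a single architecture parameter, returned by arch_parameters()
            self.alpha = nn.Parameter(1e-3 * torch.randn(4, 3))
            self.linear = nn.Linear(8, 3)

        def arch_parameters(self):
            return [self.alpha]

        def _loss(self, input, target):
            # couple the loss to alpha so gradients reach it
            logits = self.linear(input) + self.alpha.mean()
            return nn.functional.cross_entropy(logits, target)

    args = argparse.Namespace(momentum=0.9, weight_decay=3e-4,
                              arch_learning_rate=3e-4, arch_weight_decay=1e-3,
                              method='darts')
    architect = Architect(_ToyModel(), args)
    x, y = torch.randn(16, 8), torch.randint(0, 3, (16,))
    architect.step(x, y, x, y)         # one first-order alpha update
    architect.reset_arch_parameters()  # roll alphas back to their init values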