MeCo/nasbench201/architect_ig.py
HamsterMimi 2410fe9f5e update
2023-05-04 13:42:06 +08:00

52 lines
2.0 KiB
Python

import torch
class Architect(object):
def __init__(self, model, args):
self.network_momentum = args.momentum
self.network_weight_decay = args.weight_decay
self.model = model
self.optimizer = torch.optim.Adam(self.model.arch_parameters(),
lr=args.arch_learning_rate, betas=(0.5, 0.999),
weight_decay=args.arch_weight_decay)
self._init_arch_parameters = []
for alpha in self.model.arch_parameters():
alpha_init = torch.zeros_like(alpha)
alpha_init.data.copy_(alpha)
self._init_arch_parameters.append(alpha_init)
#### mode
if args.method in ['darts', 'darts-proj', 'sdarts', 'sdarts-proj']:
self.method = 'fo' # first order update
elif 'so' in args.method:
print('ERROR: PLEASE USE architect.py for second order darts')
elif args.method in ['blank', 'blank-proj']:
self.method = 'blank'
else:
print('ERROR: WRONG ARCH UPDATE METHOD', args.method); exit(0)
def reset_arch_parameters(self):
for alpha, alpha_init in zip(self.model.arch_parameters(), self._init_arch_parameters):
alpha.data.copy_(alpha_init.data)
def step(self, input_train, target_train, input_valid, target_valid, *args, **kwargs):
if self.method == 'fo':
shared = self._step_fo(input_train, target_train, input_valid, target_valid)
elif self.method == 'so':
raise NotImplementedError
elif self.method == 'blank': ## do not update alpha
shared = None
return shared
#### first order
def _step_fo(self, input_train, target_train, input_valid, target_valid):
loss = self.model._loss(input_valid, target_valid)
loss.backward()
self.optimizer.step()
return None
#### darts 2nd order
def _step_darts_so(self, input_train, target_train, input_valid, target_valid, eta, model_optimizer):
raise NotImplementedError