update affines for NAS

parent 487fec21bf
commit d175a361bd

.gitignore (vendored)
@@ -111,3 +111,4 @@ logs
 # snapshot
 a.pth
 cal-merge*.sh
+GPU-*.sh
@@ -1,7 +1,7 @@
 {
   "scheduler": ["str", "cos"],
   "eta_min" : ["float", "0.0"],
-  "epochs"  : ["int", "10"],
+  "epochs"  : ["int", "12"],
   "warmup"  : ["int", "0"],
   "optim"   : ["str", "SGD"],
   "LR"      : ["float", "0.1"],
@@ -15,10 +15,10 @@ from procedures import get_machine_info
 from datasets import get_datasets
 from log_utils import Logger, AverageMeter, time_string, convert_secs2time
 from models import CellStructure, CellArchitectures, get_search_spaces
-from AA_functions import evaluate_for_seed
+from AA_functions_v2 import evaluate_for_seed


-def evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, workers, logger):
+def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger):
   machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
   all_infos = {'info': machine_info}
   all_dataset_keys = []
@@ -28,10 +28,12 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, wor
     train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
     # load the configurature
     if dataset == 'cifar10' or dataset == 'cifar100':
-      config_path = 'configs/nas-benchmark/CIFAR.config'
+      if use_less: config_path = 'configs/nas-benchmark/LESS.config'
+      else       : config_path = 'configs/nas-benchmark/CIFAR.config'
       split_info = load_config('configs/nas-benchmark/cifar-split.txt', None, None)
     elif dataset.startswith('ImageNet16'):
-      config_path = 'configs/nas-benchmark/ImageNet-16.config'
+      if use_less: config_path = 'configs/nas-benchmark/LESS.config'
+      else       : config_path = 'configs/nas-benchmark/ImageNet-16.config'
       split_info = load_config('configs/nas-benchmark/{:}-split.txt'.format(dataset), None, None)
     else:
       raise ValueError('invalid dataset : {:}'.format(dataset))
@@ -41,6 +43,8 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, wor
                          logger)
     # check whether use splited validation set
     if bool(split):
+      assert dataset == 'cifar10'
+      ValLoaders = {'ori-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)}
       assert len(train_data) == len(split_info.train) + len(split_info.valid), 'invalid length : {:} vs {:} + {:}'.format(len(train_data), len(split_info.train), len(split_info.valid))
       train_data_v2 = deepcopy(train_data)
       train_data_v2.transform = valid_data.transform
@@ -48,23 +52,42 @@ def evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, wor
       # data loader
       train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train), num_workers=workers, pin_memory=True)
       valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid), num_workers=workers, pin_memory=True)
+      ValLoaders['x-valid'] = valid_loader
     else:
       # data loader
       train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, shuffle=True , num_workers=workers, pin_memory=True)
       valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, shuffle=False, num_workers=workers, pin_memory=True)
+      if dataset == 'cifar10':
+        ValLoaders = {'ori-test': valid_loader}
+      elif dataset == 'cifar100':
+        cifar100_splits = load_config('configs/nas-benchmark/cifar100-test-split.txt', None, None)
+        ValLoaders = {'ori-test': valid_loader,
+                      'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xvalid), num_workers=workers, pin_memory=True),
+                      'x-test'  : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_splits.xtest ), num_workers=workers, pin_memory=True)
+                     }
+      elif dataset == 'ImageNet16-120':
+        imagenet16_splits = load_config('configs/nas-benchmark/imagenet-16-120-test-split.txt', None, None)
+        ValLoaders = {'ori-test': valid_loader,
+                      'x-valid' : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xvalid), num_workers=workers, pin_memory=True),
+                      'x-test'  : torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet16_splits.xtest ), num_workers=workers, pin_memory=True)
+                     }
+      else:
+        raise ValueError('invalid dataset : {:}'.format(dataset))

     dataset_key = '{:}'.format(dataset)
     if bool(split): dataset_key = dataset_key + '-valid'
     logger.log('Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
     logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(dataset_key, config))
-    results = evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, seed, logger)
+    for key, value in ValLoaders.items():
+      logger.log('Evaluate ---->>>> {:10s} with {:} batchs'.format(key, len(value)))
+    results = evaluate_for_seed(arch_config, config, arch, train_loader, ValLoaders, seed, logger)
     all_infos[dataset_key] = results
     all_dataset_keys.append( dataset_key )
   all_infos['all_dataset_keys'] = all_dataset_keys
   return all_infos


-def main(save_dir, workers, datasets, xpaths, splits, srange, arch_index, seeds, cover_mode, meta_info, arch_config):
+def main(save_dir, workers, datasets, xpaths, splits, use_less, srange, arch_index, seeds, cover_mode, meta_info, arch_config):
   assert torch.cuda.is_available(), 'CUDA is not available.'
   torch.backends.cudnn.enabled = True
   #torch.backends.cudnn.benchmark = True
@@ -73,7 +96,10 @@ def main(save_dir, workers, datasets, xpaths, splits, srange, arch_index, seeds,

   assert len(srange) == 2 and 0 <= srange[0] <= srange[1], 'invalid srange : {:}'.format(srange)

-  sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
+  if use_less:
+    sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}-LESS'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
+  else:
+    sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
   logger = Logger(str(sub_dir), 0, False)

   all_archs = meta_info['archs']
@@ -114,7 +140,7 @@ def main(save_dir, workers, datasets, xpaths, splits, srange, arch_index, seeds,
         has_continue = True
         continue
       results = evaluate_all_datasets(CellStructure.str2structure(arch), \
-                                      datasets, xpaths, splits, seed, \
+                                      datasets, xpaths, splits, use_less, seed, \
                                       arch_config, workers, logger)
       torch.save(results, to_save_name)
       logger.log('{:} --evaluate-- {:06d}/{:06d} ({:06d}/{:06d})-th seed={:} done, save into {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seed, to_save_name))
@@ -130,7 +156,7 @@ def main(save_dir, workers, datasets, xpaths, splits, srange, arch_index, seeds,
   logger.close()


-def train_single_model(save_dir, workers, datasets, xpaths, splits, seeds, model_str, arch_config):
+def train_single_model(save_dir, workers, datasets, xpaths, use_less, splits, seeds, model_str, arch_config):
   assert torch.cuda.is_available(), 'CUDA is not available.'
   torch.backends.cudnn.enabled = True
   torch.backends.cudnn.deterministic = True
@@ -160,7 +186,7 @@ def train_single_model(save_dir, workers, datasets, xpaths, splits, seeds, model
       checkpoint = torch.load(to_save_name)
     else:
       logger.log('Does not find the existing file {:}, train and evaluate!'.format(to_save_name))
-      checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, seed, arch_config, workers, logger)
+      checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger)
       torch.save(checkpoint, to_save_name)
     # log information
     logger.log('{:}'.format(checkpoint['info']))
@@ -252,6 +278,7 @@ if __name__ == '__main__':
   parser.add_argument('--datasets', type=str, nargs='+', help='The applied datasets.')
   parser.add_argument('--xpaths', type=str, nargs='+', help='The root path for this dataset.')
   parser.add_argument('--splits', type=int, nargs='+', help='The root path for this dataset.')
+  parser.add_argument('--use_less', type=int, default=0, help='Using the less-training-epoch config.')
   parser.add_argument('--seeds' , type=int, nargs='+', help='The range of models to be evaluated')
   parser.add_argument('--channel', type=int, help='The number of channels.')
   parser.add_argument('--num_cells', type=int, help='The number of cells in one stage.')
@@ -264,7 +291,7 @@ if __name__ == '__main__':
   elif args.mode.startswith('specific'):
     assert len(args.mode.split('-')) == 2, 'invalid mode : {:}'.format(args.mode)
     model_str = args.mode.split('-')[1]
-    train_single_model(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, \
+    train_single_model(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.use_less>0, \
                        tuple(args.seeds), model_str, {'channel': args.channel, 'num_cells': args.num_cells})
   else:
     meta_path = Path(args.save_dir) / 'meta-node-{:}.pth'.format(args.max_node)
@@ -276,7 +303,7 @@ if __name__ == '__main__':
     assert len(args.datasets) == len(args.xpaths) == len(args.splits), 'invalid infos : {:} vs {:} vs {:}'.format(len(args.datasets), len(args.xpaths), len(args.splits))
     assert args.workers > 0, 'invalid number of workers : {:}'.format(args.workers)

-    main(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, \
+    main(args.save_dir, args.workers, args.datasets, args.xpaths, args.splits, args.use_less>0, \
         tuple(args.srange), args.arch_index, tuple(args.seeds), \
         args.mode == 'cover', meta_info, \
         {'channel': args.channel, 'num_cells': args.num_cells})
@@ -47,6 +47,7 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode):
   elif mode == 'valid': network.eval()
   else: raise ValueError("The mode is not right : {:}".format(mode))

+  batch_time, end = AverageMeter(), time.time()
   for i, (inputs, targets) in enumerate(xloader):
     if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))

@@ -64,7 +65,10 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode):
     losses.update(loss.item(), inputs.size(0))
     top1.update (prec1.item(), inputs.size(0))
     top5.update (prec5.item(), inputs.size(0))
-  return losses.avg, top1.avg, top5.avg
+    # count time
+    batch_time.update(time.time() - end)
+    end = time.time()
+  return losses.avg, top1.avg, top5.avg, batch_time.sum


@@ -87,18 +91,21 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, see
   # start training
   start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
   train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
+  train_times , valid_times = {}, {}
   for epoch in range(total_epoch):
     scheduler.update(epoch, 0.0)

-    train_loss, train_acc1, train_acc5 = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
+    train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
     with torch.no_grad():
-      valid_loss, valid_acc1, valid_acc5 = procedure(valid_loader, network, criterion, None, None, 'valid')
+      valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(valid_loader, network, criterion, None, None, 'valid')
     train_losses[epoch] = train_loss
     train_acc1es[epoch] = train_acc1
     train_acc5es[epoch] = train_acc5
     valid_losses[epoch] = valid_loss
     valid_acc1es[epoch] = valid_acc1
     valid_acc5es[epoch] = valid_acc5
+    train_times [epoch] = train_tm
+    valid_times [epoch] = valid_tm

     # measure elapsed time
     epoch_time.update(time.time() - start_time)
@@ -114,9 +121,11 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loader, see
              'train_losses': train_losses,
              'train_acc1es': train_acc1es,
              'train_acc5es': train_acc5es,
+             'train_times' : train_times,
              'valid_losses': valid_losses,
              'valid_acc1es': valid_acc1es,
              'valid_acc5es': valid_acc5es,
+             'valid_times' : valid_times,
              'net_state_dict': net.state_dict(),
              'net_string' : '{:}'.format(net),
              'finish-train': True
@@ -19,9 +19,9 @@ class InferCell(nn.Module):
       cur_innod = []
       for (op_name, op_in) in node_info:
         if op_in == 0:
-          layer = OPS[op_name](C_in , C_out, stride)
+          layer = OPS[op_name](C_in , C_out, stride, True)
         else:
-          layer = OPS[op_name](C_out, C_out, 1)
+          layer = OPS[op_name](C_out, C_out, 1, True)
         cur_index.append( len(self.layers) )
         cur_innod.append( op_in )
         self.layers.append( layer )
@@ -22,7 +22,7 @@ class TinyNetwork(nn.Module):
     self.cells = nn.ModuleList()
     for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
       if reduction:
-        cell = ResNetBasicblock(C_prev, C_curr, 2)
+        cell = ResNetBasicblock(C_prev, C_curr, 2, True)
       else:
         cell = InferCell(genotype, C_prev, C_curr, 1)
       self.cells.append( cell )
@@ -4,16 +4,16 @@
 import torch
 import torch.nn as nn

-__all__ = ['OPS', 'ReLUConvBN', 'ResNetBasicblock', 'SearchSpaceNames']
+__all__ = ['OPS', 'ResNetBasicblock', 'SearchSpaceNames']

 OPS = {
-  'none'         : lambda C_in, C_out, stride: Zero(C_in, C_out, stride),
-  'avg_pool_3x3' : lambda C_in, C_out, stride: POOLING(C_in, C_out, stride, 'avg'),
-  'max_pool_3x3' : lambda C_in, C_out, stride: POOLING(C_in, C_out, stride, 'max'),
-  'nor_conv_7x7' : lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1)),
-  'nor_conv_3x3' : lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1)),
-  'nor_conv_1x1' : lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1)),
-  'skip_connect' : lambda C_in, C_out, stride: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride),
+  'none'         : lambda C_in, C_out, stride, affine: Zero(C_in, C_out, stride),
+  'avg_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'avg'),
+  'max_pool_3x3' : lambda C_in, C_out, stride, affine: POOLING(C_in, C_out, stride, 'max'),
+  'nor_conv_7x7' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (7,7), (stride,stride), (3,3), (1,1), affine),
+  'nor_conv_3x3' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine),
+  'nor_conv_1x1' : lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (1,1), (stride,stride), (0,0), (1,1), affine),
+  'skip_connect' : lambda C_in, C_out, stride, affine: Identity() if stride == 1 and C_in == C_out else FactorizedReduce(C_in, C_out, stride, affine),
 }

 CONNECT_NAS_BENCHMARK = ['none', 'skip_connect', 'nor_conv_3x3']
@@ -26,12 +26,12 @@ SearchSpaceNames = {'connect-nas' : CONNECT_NAS_BENCHMARK,

 class ReLUConvBN(nn.Module):

-  def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
+  def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine):
     super(ReLUConvBN, self).__init__()
     self.op = nn.Sequential(
       nn.ReLU(inplace=False),
       nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
-      nn.BatchNorm2d(C_out)
+      nn.BatchNorm2d(C_out, affine=affine)
     )

   def forward(self, x):
@@ -40,17 +40,17 @@ class ReLUConvBN(nn.Module):

 class ResNetBasicblock(nn.Module):

-  def __init__(self, inplanes, planes, stride):
+  def __init__(self, inplanes, planes, stride, affine=True):
     super(ResNetBasicblock, self).__init__()
     assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
-    self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1)
-    self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1)
+    self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, affine)
+    self.conv_b = ReLUConvBN( planes, planes, 3, 1, 1, 1, affine)
     if stride == 2:
       self.downsample = nn.Sequential(
         nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
         nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
     elif inplanes != planes:
-      self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1)
+      self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, affine)
     else:
       self.downsample = None
     self.in_dim = inplanes
@@ -76,12 +76,12 @@ class ResNetBasicblock(nn.Module):

 class POOLING(nn.Module):

-  def __init__(self, C_in, C_out, stride, mode):
+  def __init__(self, C_in, C_out, stride, mode, affine=True):
     super(POOLING, self).__init__()
     if C_in == C_out:
       self.preprocess = None
     else:
-      self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0)
+      self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, affine)
     if mode == 'avg' : self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
     elif mode == 'max': self.op = nn.MaxPool2d(3, stride=stride, padding=1)
     else : raise ValueError('Invalid mode={:} in POOLING'.format(mode))
@@ -126,7 +126,7 @@ class Zero(nn.Module):

 class FactorizedReduce(nn.Module):

-  def __init__(self, C_in, C_out, stride):
+  def __init__(self, C_in, C_out, stride, affine):
     super(FactorizedReduce, self).__init__()
     self.stride = stride
     self.C_in = C_in
@@ -141,8 +141,7 @@ class FactorizedReduce(nn.Module):
       self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
     else:
       raise ValueError('Invalid stride : {:}'.format(stride))
-
-    self.bn = nn.BatchNorm2d(C_out)
+    self.bn = nn.BatchNorm2d(C_out, affine=affine)

   def forward(self, x):
     x = self.relu(x)
@@ -23,9 +23,9 @@ class SearchCell(nn.Module):
       for j in range(i):
         node_str = '{:}<-{:}'.format(i, j)
         if j == 0:
-          xlists = [OPS[op_name](C_in , C_out, stride) for op_name in op_names]
+          xlists = [OPS[op_name](C_in , C_out, stride, False) for op_name in op_names]
         else:
-          xlists = [OPS[op_name](C_in , C_out, 1) for op_name in op_names]
+          xlists = [OPS[op_name](C_in , C_out, 1, False) for op_name in op_names]
         self.edges[ node_str ] = nn.ModuleList( xlists )
     self.edge_keys = sorted(list(self.edges.keys()))
     self.edge2index = {key:i for i, key in enumerate(self.edge_keys)}
@@ -29,6 +29,7 @@ save_dir=./output/AA-NAS-BENCH-4/

 OMP_NUM_THREADS=4 python ./exps/AA-NAS-Bench-main.py \
 	--mode ${mode} --save_dir ${save_dir} --max_node 4 \
+	--use_less 0 \
 	--datasets cifar10 cifar10 cifar100 ImageNet16-120 \
 	--splits 1 0 0 0 \
 	--xpaths $TORCH_HOME/cifar.python \
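
The common thread in the hunks above is an affine flag threaded from the OPS factories down to every nn.BatchNorm2d: the search cells pass False (no learnable scale/shift during architecture search), while the inferred networks pass True. Below is a minimal, self-contained sketch of that pattern; it is illustrative only and not part of the commit, and the names (ReLUConvBN, OPS) simply mirror those in the diff.

# Sketch only: shows how an `affine` flag can flow from an OPS-style factory
# dict into nn.BatchNorm2d, as the diff above does; not the repository code.
import torch
import torch.nn as nn

class ReLUConvBN(nn.Module):
  def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine):
    super(ReLUConvBN, self).__init__()
    self.op = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False),
      nn.BatchNorm2d(C_out, affine=affine))  # affine=False disables the learnable gamma/beta

  def forward(self, x):
    return self.op(x)

# Every factory takes (C_in, C_out, stride, affine), mirroring the updated OPS dict.
OPS = {
  'nor_conv_3x3': lambda C_in, C_out, stride, affine: ReLUConvBN(C_in, C_out, (3,3), (stride,stride), (1,1), (1,1), affine),
}

if __name__ == '__main__':
  search_op = OPS['nor_conv_3x3'](16, 16, 1, False)  # search cell: affine off
  infer_op  = OPS['nor_conv_3x3'](16, 16, 1, True)   # evaluated network: affine on
  x = torch.randn(2, 16, 8, 8)
  print(search_op(x).shape, infer_op(x).shape)

Running the sketch builds one op whose batch-norm scale and shift are learnable and one whose are not; both produce tensors of the same shape.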