###############################################################
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08           #
###############################################################
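# This script trains and evaluates a single NAS-Bench-201 cell architecture
# on one or more datasets, caching the per-seed results as .pth checkpoints.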
from functions import evaluate_for_seed
from nas_bench_201_models import CellStructure, CellArchitectures, get_search_spaces
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
from nas_bench_201_datasets import get_datasets
from procedures import get_machine_info, save_checkpoint, copy_checkpoint
from config_utils import load_config
from pathlib import Path
from copy import deepcopy
import json
import os
import sys
import time
import torch
import random
import argparse
from PIL import ImageFile

# Allow PIL to load images whose files are truncated on disk.
ImageFile.LOAD_TRUNCATED_IMAGES = True

NASBENCH201_CONFIG_PATH = os.path.join(
    os.getcwd(), 'main_exp', 'transfer_nag')


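# Train and evaluate one architecture on every dataset in `datasets`, and
# gather the per-dataset results (plus machine info) into a single dict.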
def evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed,
                          arch_config, workers, logger):
    machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
    all_infos = {'info': machine_info}
    all_dataset_keys = []
    # loop over all the datasets
    for dataset, xpath, split in zip(datasets, xpaths, splits):
        # train/valid data
        task = None
        train_data, valid_data, xshape, class_num = get_datasets(
            dataset, xpath, -1, task)

        # load the configuration
        if dataset in ['mnist', 'svhn', 'aircraft', 'pets']:
            if use_less:
                config_path = os.path.join(
                    NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/LESS.config')
            else:
                config_path = os.path.join(
                    NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{:}.config'.format(dataset))

            p = os.path.join(
                NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{:}-split.txt'.format(dataset))
            if not os.path.exists(p):
                # no cached split file: create a random half/half
                # train/valid split and save it for reuse
                label_list = list(range(len(train_data)))
                random.shuffle(label_list)
                strlist = [str(label_list[i]) for i in range(len(label_list))]
                split_data = {'train': ["int", strlist[:len(train_data) // 2]],
                              'valid': ["int", strlist[len(train_data) // 2:]]}
                with open(p, 'w') as f:
                    f.write(json.dumps(split_data))
            split_info = load_config(p, None, None)
        else:
            raise ValueError('invalid dataset : {:}'.format(dataset))

        config = load_config(
            config_path, {'class_num': class_num, 'xshape': xshape}, logger)
        # data loaders
        train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size,
                                                   shuffle=True, num_workers=workers, pin_memory=True)
        valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
                                                   shuffle=False, num_workers=workers, pin_memory=True)
        # held-out index splits for the additional validation/test loaders
        test_splits = load_config(os.path.join(
            NASBENCH201_CONFIG_PATH, 'nas_bench_201/configs/nas-benchmark/{:}-test-split.txt'.format(dataset)), None, None)
        ValLoaders = {'ori-test': valid_loader,
                      'x-valid': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
                                                             sampler=torch.utils.data.sampler.SubsetRandomSampler(
                                                                 test_splits.xvalid),
                                                             num_workers=workers, pin_memory=True),
                      'x-test': torch.utils.data.DataLoader(valid_data, batch_size=config.batch_size,
                                                            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                                                                test_splits.xtest),
                                                            num_workers=workers, pin_memory=True)
                      }
        dataset_key = '{:}'.format(dataset)
        if bool(split):
            dataset_key = dataset_key + '-valid'
        logger.log(
            'Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}'.
            format(dataset_key, len(train_data), len(valid_data), len(train_loader), len(valid_loader), config.batch_size))
        logger.log('Evaluate ||||||| {:10s} ||||||| Config={:}'.format(
            dataset_key, config))
        for key, value in ValLoaders.items():
            logger.log(
                'Evaluate ---->>>> {:10s} with {:} batches'.format(key, len(value)))

        results = evaluate_for_seed(
            arch_config, config, arch, train_loader, ValLoaders, seed, logger)
        all_infos[dataset_key] = results
        all_dataset_keys.append(dataset_key)
    all_infos['all_dataset_keys'] = all_dataset_keys
    return all_infos


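# Train `model_str` once per seed, reusing a cached seed-{seed}.pth checkpoint
# in `save_dir` when it exists, and log a summary for every dataset.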
def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less,
                       seeds, model_str, arch_config):
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(workers)

    save_dir = Path(save_dir)
    logger = Logger(str(save_dir), 0, False)

    if model_str in CellArchitectures:
        arch = CellArchitectures[model_str]
        logger.log(
            'The model string is found in the pre-defined architecture dict : {:}'.format(model_str))
    else:
        try:
            arch = CellStructure.str2structure(model_str)
        except Exception:
            raise ValueError(
                'Invalid model string : {:}. It cannot be found or parsed.'.format(model_str))

    assert arch.check_valid_op(get_search_spaces(
        'cell', 'nas-bench-201')), '{:} has an invalid op.'.format(arch)
    # assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has an invalid op.'.format(arch)
    logger.log('Start train-evaluate {:}'.format(arch.tostr()))
    logger.log('arch_config : {:}'.format(arch_config))

    start_time, seed_time = time.time(), AverageMeter()
    for _is, seed in enumerate(seeds):
        logger.log(
            '\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(
                _is, len(seeds), seed))
        to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
        if to_save_name.exists():
            logger.log(
                'Found the existing file {:}, directly load!'.format(to_save_name))
            checkpoint = torch.load(to_save_name)
        else:
            logger.log(
                'Did not find the existing file {:}, train and evaluate!'.format(to_save_name))
            checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less,
                                               seed, arch_config, workers, logger)
            torch.save(checkpoint, to_save_name)
        # log information
        logger.log('{:}'.format(checkpoint['info']))
        all_dataset_keys = checkpoint['all_dataset_keys']
        for dataset_key in all_dataset_keys:
            logger.log('\n{:} dataset : {:} {:}'.format(
                '-' * 15, dataset_key, '-' * 15))
            dataset_info = checkpoint[dataset_key]
            # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
            logger.log('Flops = {:} MB, Params = {:} MB'.format(
                dataset_info['flop'], dataset_info['param']))
            logger.log('config : {:}'.format(dataset_info['config']))
            logger.log('Training State (finish) = {:}'.format(
                dataset_info['finish-train']))
            last_epoch = dataset_info['total_epoch'] - 1
            # per-epoch accuracy records from this run (kept for inspection)
            train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
            valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
        # measure elapsed time
        seed_time.update(time.time() - start_time)
        start_time = time.time()
        need_time = 'Time Left: {:}'.format(convert_secs2time(
            seed_time.avg * (len(seeds) - _is - 1), True))
        logger.log(
            '\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}'.format(
                _is, len(seeds), seed, need_time))
    logger.close()
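

# A minimal command-line entry point, sketched as an assumption: this module
# defines only the two functions above, and `argparse` is imported but never
# used, so the flag names and defaults below are illustrative rather than the
# project's actual CLI.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Train a single NAS-Bench-201 architecture on the given datasets.')
    parser.add_argument('--save_dir', type=str, required=True,
                        help='directory for logs and the seed-*.pth checkpoints')
    parser.add_argument('--datasets', type=str, nargs='+', required=True,
                        help='dataset names, e.g. mnist svhn aircraft pets')
    parser.add_argument('--xpaths', type=str, nargs='+', required=True,
                        help='one data-root path per dataset')
    parser.add_argument('--splits', type=int, nargs='+', required=True,
                        help='per-dataset flag; non-zero marks the -valid variant')
    parser.add_argument('--use_less', type=int, default=0,
                        help='set to 1 to use the reduced LESS.config')
    parser.add_argument('--seeds', type=int, nargs='+', required=True,
                        help='random seeds; one train/evaluate run per seed')
    parser.add_argument('--model_str', type=str, required=True,
                        help='architecture string or a CellArchitectures key')
    parser.add_argument('--channel', type=int, default=16,
                        help='stem channel count (assumed arch_config key)')
    parser.add_argument('--num_cells', type=int, default=5,
                        help='cells per stage (assumed arch_config key)')
    parser.add_argument('--workers', type=int, default=4)
    args = parser.parse_args()
    train_single_model(args.save_dir, args.workers, args.datasets, args.xpaths,
                       args.splits, args.use_less > 0, tuple(args.seeds),
                       args.model_str,
                       {'channel': args.channel, 'num_cells': args.num_cells})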