import os
import sys
import time
import random
import argparse

import torch
from copy import deepcopy
from pathlib import Path
from PIL import ImageFile

# Allow PIL to load truncated image files without raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

from nas_201_api import NASBench201API as API

from xautodl.config_utils import load_config, dict2config
from xautodl.procedures import save_checkpoint, copy_checkpoint, get_machine_info
from xautodl.procedures import prepare_seed, get_optim_scheduler
from xautodl.datasets import get_datasets
from xautodl.log_utils import Logger, AverageMeter, time_string, convert_secs2time
from xautodl.models import CellStructure, CellArchitectures, get_search_spaces
from xautodl.models import get_cell_based_tiny_net
from xautodl.utils import get_model_infos, obtain_accuracy

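# Prerequisites assumed by this script: the `nas_201_api` and `xautodl` packages are
# importable, the benchmark file `NAS-Bench-201-v1_1-096897.pth` is in the working
# directory, and the `configs/nas-benchmark/*` files and `./datasets/*` folders
# referenced below are available relative to this script.
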
# Load the NAS-Bench-201 benchmark (this can take a while).
cur_path = os.path.abspath(os.path.curdir)
data_path = os.path.join(cur_path, 'NAS-Bench-201-v1_1-096897.pth')
print(f'loading data from {data_path}')
api = API(data_path)
print('loaded')

def find_best_index(dataset):
    """Return (index, accuracy) of the architecture with the highest training accuracy on `dataset`."""
    num_archs = 15625  # total number of architectures in NAS-Bench-201 (indices 0..15624)
    accs = []
    for i in range(num_archs):
        results = api.query_by_index(i, dataset)
        # `results` maps a training seed to its result object; use the first seed.
        seed, result = next(iter(results.items()))
        train_info = result.get_train()
        accs.append((i, train_info['accuracy']))
    return max(accs, key=lambda x: x[1])


best_cifar_10_index, best_cifar_10_acc = find_best_index('cifar10')
best_cifar_100_index, best_cifar_100_acc = find_best_index('cifar100')
best_ImageNet16_index, best_ImageNet16_acc = find_best_index('ImageNet16-120')
print(f'found best cifar10 index: {best_cifar_10_index}, acc: {best_cifar_10_acc}')
print(f'found best cifar100 index: {best_cifar_100_index}, acc: {best_cifar_100_acc}')
print(f'found best ImageNet16-120 index: {best_ImageNet16_index}, acc: {best_ImageNet16_acc}')

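# Note (assumption about the nas_201_api package, not verified here): recent API
# versions also expose `api.find_best(...)`, which may return the best index for a
# dataset directly and could replace the manual loop above.
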
def get_network_str_by_id(id, dataset):
    """Return the architecture string of benchmark entry `id` for `dataset`."""
    config = api.get_net_config(id, dataset)
    return config['arch_str']


best_cifar_10_str = get_network_str_by_id(best_cifar_10_index, 'cifar10')
best_cifar_100_str = get_network_str_by_id(best_cifar_100_index, 'cifar100')
best_ImageNet16_str = get_network_str_by_id(best_ImageNet16_index, 'ImageNet16-120')

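# Illustrative sketch (assumes the standard nas_201_api/xautodl pattern): the config
# returned by `api.get_net_config` can be instantiated as a PyTorch model, shown here
# for the best CIFAR-10 architecture found above.
example_config = api.get_net_config(best_cifar_10_index, 'cifar10')
example_net = get_cell_based_tiny_net(dict2config(example_config, None))
print(f'example network built from arch: {example_config["arch_str"]}')
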
def evaluate_all_datasets(
    arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger
):
    machine_info, arch_config = get_machine_info(), deepcopy(arch_config)
    all_infos = {"info": machine_info}
    all_dataset_keys = []
    # loop over all the datasets
    for dataset, xpath, split in zip(datasets, xpaths, splits):
        # train/valid data
        train_data, valid_data, xshape, class_num = get_datasets(dataset, xpath, -1)
        # load the configuration
        if dataset == "cifar10" or dataset == "cifar100":
            if use_less:
                config_path = "configs/nas-benchmark/LESS.config"
            else:
                config_path = "configs/nas-benchmark/CIFAR.config"
            split_info = load_config(
                "configs/nas-benchmark/cifar-split.txt", None, None
            )
        elif dataset.startswith("ImageNet16"):
            if use_less:
                config_path = "configs/nas-benchmark/LESS.config"
            else:
                config_path = "configs/nas-benchmark/ImageNet-16.config"
            split_info = load_config(
                "configs/nas-benchmark/{:}-split.txt".format(dataset), None, None
            )
        else:
            raise ValueError("invalid dataset : {:}".format(dataset))
        config = load_config(
            config_path, {"class_num": class_num, "xshape": xshape}, logger
        )
        # check whether to use the split validation set
        if bool(split):
            assert dataset == "cifar10"
            ValLoaders = {
                "ori-test": torch.utils.data.DataLoader(
                    valid_data,
                    batch_size=config.batch_size,
                    shuffle=False,
                    num_workers=workers,
                    pin_memory=True,
                )
            }
            assert len(train_data) == len(split_info.train) + len(
                split_info.valid
            ), "invalid length : {:} vs {:} + {:}".format(
                len(train_data), len(split_info.train), len(split_info.valid)
            )
            train_data_v2 = deepcopy(train_data)
            train_data_v2.transform = valid_data.transform
            valid_data = train_data_v2
            # data loaders
            train_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=config.batch_size,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.train),
                num_workers=workers,
                pin_memory=True,
            )
            valid_loader = torch.utils.data.DataLoader(
                valid_data,
                batch_size=config.batch_size,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(split_info.valid),
                num_workers=workers,
                pin_memory=True,
            )
            ValLoaders["x-valid"] = valid_loader
        else:
            # data loaders
            train_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=config.batch_size,
                shuffle=True,
                num_workers=workers,
                pin_memory=True,
            )
            valid_loader = torch.utils.data.DataLoader(
                valid_data,
                batch_size=config.batch_size,
                shuffle=False,
                num_workers=workers,
                pin_memory=True,
            )
            if dataset == "cifar10":
                ValLoaders = {"ori-test": valid_loader}
            elif dataset == "cifar100":
                cifar100_splits = load_config(
                    "configs/nas-benchmark/cifar100-test-split.txt", None, None
                )
                ValLoaders = {
                    "ori-test": valid_loader,
                    "x-valid": torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                            cifar100_splits.xvalid
                        ),
                        num_workers=workers,
                        pin_memory=True,
                    ),
                    "x-test": torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                            cifar100_splits.xtest
                        ),
                        num_workers=workers,
                        pin_memory=True,
                    ),
                }
            elif dataset == "ImageNet16-120":
                imagenet16_splits = load_config(
                    "configs/nas-benchmark/imagenet-16-120-test-split.txt", None, None
                )
                ValLoaders = {
                    "ori-test": valid_loader,
                    "x-valid": torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                            imagenet16_splits.xvalid
                        ),
                        num_workers=workers,
                        pin_memory=True,
                    ),
                    "x-test": torch.utils.data.DataLoader(
                        valid_data,
                        batch_size=config.batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                            imagenet16_splits.xtest
                        ),
                        num_workers=workers,
                        pin_memory=True,
                    ),
                }
            else:
                raise ValueError("invalid dataset : {:}".format(dataset))

        dataset_key = "{:}".format(dataset)
        if bool(split):
            dataset_key = dataset_key + "-valid"
        logger.log(
            "Evaluate ||||||| {:10s} ||||||| Train-Num={:}, Valid-Num={:}, Train-Loader-Num={:}, Valid-Loader-Num={:}, batch size={:}".format(
                dataset_key,
                len(train_data),
                len(valid_data),
                len(train_loader),
                len(valid_loader),
                config.batch_size,
            )
        )
        logger.log(
            "Evaluate ||||||| {:10s} ||||||| Config={:}".format(dataset_key, config)
        )
        for key, value in ValLoaders.items():
            logger.log(
                "Evaluate ---->>>> {:10s} with {:} batches".format(key, len(value))
            )
        results = evaluate_for_seed(
            arch_config, config, arch, train_loader, ValLoaders, seed, logger
        )
        all_infos[dataset_key] = results
        all_dataset_keys.append(dataset_key)
    all_infos["all_dataset_keys"] = all_dataset_keys
    return all_infos


def evaluate_for_seed(
    arch_config, config, arch, train_loader, valid_loaders, seed, logger
):
    prepare_seed(seed)  # fix the random seed
    net = get_cell_based_tiny_net(
        dict2config(
            {
                "name": "infer.tiny",
                "C": arch_config["channel"],
                "N": arch_config["num_cells"],
                "genotype": arch,
                "num_classes": config.class_num,
            },
            None,
        )
    )
    # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
    flop, param = get_model_infos(net, config.xshape)
    logger.log("Network : {:}".format(net.get_message()), False)
    logger.log(
        "{:} Seed-------------------------- {:} --------------------------".format(
            time_string(), seed
        )
    )
    logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param))
    # train and valid
    optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config)
    network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
    # start training
    start_time, epoch_time, total_epoch = (
        time.time(),
        AverageMeter(),
        config.epochs + config.warmup,
    )
    (
        train_losses,
        train_acc1es,
        train_acc5es,
        valid_losses,
        valid_acc1es,
        valid_acc5es,
    ) = ({}, {}, {}, {}, {}, {})
    train_times, valid_times = {}, {}
    for epoch in range(total_epoch):
        scheduler.update(epoch, 0.0)

        train_loss, train_acc1, train_acc5, train_tm = procedure(
            train_loader, network, criterion, scheduler, optimizer, "train"
        )
        train_losses[epoch] = train_loss
        train_acc1es[epoch] = train_acc1
        train_acc5es[epoch] = train_acc5
        train_times[epoch] = train_tm
        with torch.no_grad():
            for key, xloader in valid_loaders.items():
                valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(
                    xloader, network, criterion, None, None, "valid"
                )
                valid_losses["{:}@{:}".format(key, epoch)] = valid_loss
                valid_acc1es["{:}@{:}".format(key, epoch)] = valid_acc1
                valid_acc5es["{:}@{:}".format(key, epoch)] = valid_acc5
                valid_times["{:}@{:}".format(key, epoch)] = valid_tm

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        need_time = "Time Left: {:}".format(
            convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)
        )
        logger.log(
            "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]".format(
                time_string(),
                need_time,
                epoch,
                total_epoch,
                train_loss,
                train_acc1,
                train_acc5,
                valid_loss,
                valid_acc1,
                valid_acc5,
            )
        )
    info_seed = {
        "flop": flop,
        "param": param,
        "channel": arch_config["channel"],
        "num_cells": arch_config["num_cells"],
        "config": config._asdict(),
        "total_epoch": total_epoch,
        "train_losses": train_losses,
        "train_acc1es": train_acc1es,
        "train_acc5es": train_acc5es,
        "train_times": train_times,
        "valid_losses": valid_losses,
        "valid_acc1es": valid_acc1es,
        "valid_acc5es": valid_acc5es,
        "valid_times": valid_times,
        "net_state_dict": net.state_dict(),
        "net_string": "{:}".format(net),
        "finish-train": True,
    }
    return info_seed


def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
    data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    latencies = []
    network.eval()
    with torch.no_grad():
        end = time.time()
        for i, (inputs, targets) in enumerate(xloader):
            targets = targets.cuda(non_blocking=True)
            inputs = inputs.cuda(non_blocking=True)
            data_time.update(time.time() - end)
            # forward
            features, logits = network(inputs)
            loss = criterion(logits, targets)
            batch_time.update(time.time() - end)
            if batch is None or batch == inputs.size(0):
                batch = inputs.size(0)
                latencies.append(batch_time.val - data_time.val)
            # record loss and accuracy
            prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))
            end = time.time()
    if len(latencies) > 2:
        latencies = latencies[1:]
    return losses.avg, top1.avg, top5.avg, latencies


def procedure(xloader, network, criterion, scheduler, optimizer, mode):
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    if mode == "train":
        network.train()
    elif mode == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))

    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == "train":
            scheduler.update(None, 1.0 * i / len(xloader))

        targets = targets.cuda(non_blocking=True)
        if mode == "train":
            optimizer.zero_grad()
        # forward
        features, logits = network(inputs)
        loss = criterion(logits, targets)
        # backward
        if mode == "train":
            loss.backward()
            optimizer.step()
        # record loss and accuracy
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))
        # count time
        batch_time.update(time.time() - end)
        end = time.time()
    return losses.avg, top1.avg, top5.avg, batch_time.sum

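# Usage sketch (hypothetical `trained_net` and `test_loader`, shown only for illustration):
#
#   loss, acc1, acc5, latencies = pure_evaluate(
#       test_loader, torch.nn.DataParallel(trained_net).cuda()
#   )
#
# `latencies` holds per-batch forward times (in seconds) for equally sized batches.
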

def train_single_model(
    save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config
):
    assert torch.cuda.is_available(), "CUDA is not available."
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True
    torch.set_num_threads(workers)

    save_dir = (
        Path(save_dir)
        / "specifics"
        / "{:}-{:}-{:}-{:}".format(
            "LESS" if use_less else "FULL",
            model_str,
            arch_config["channel"],
            arch_config["num_cells"],
        )
    )
    logger = Logger(str(save_dir), 0, False)
    print(CellArchitectures)
    if model_str in CellArchitectures:
        arch = CellArchitectures[model_str]
        logger.log(
            "The model string is found in pre-defined architecture dict : {:}".format(
                model_str
            )
        )
    else:
        try:
            arch = CellStructure.str2structure(model_str)
        except Exception:
            raise ValueError(
                "Invalid model string : {:}. It can not be found or parsed.".format(
                    model_str
                )
            )
    assert arch.check_valid_op(
        get_search_spaces("cell", "nas-bench-201")
    ), "{:} has the invalid op.".format(arch)
    logger.log("Start train-evaluate {:}".format(arch.tostr()))
    logger.log("arch_config : {:}".format(arch_config))

    start_time, seed_time = time.time(), AverageMeter()
    for _is, seed in enumerate(seeds):
        logger.log(
            "\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------".format(
                _is, len(seeds), seed
            )
        )
        to_save_name = save_dir / "seed-{:04d}.pth".format(seed)
        if to_save_name.exists():
            logger.log(
                "Find the existing file {:}, directly load!".format(to_save_name)
            )
            checkpoint = torch.load(to_save_name)
        else:
            logger.log(
                "Does not find the existing file {:}, train and evaluate!".format(
                    to_save_name
                )
            )
            checkpoint = evaluate_all_datasets(
                arch,
                datasets,
                xpaths,
                splits,
                use_less,
                seed,
                arch_config,
                workers,
                logger,
            )
            torch.save(checkpoint, to_save_name)
        # log information
        logger.log("{:}".format(checkpoint["info"]))
        all_dataset_keys = checkpoint["all_dataset_keys"]
        for dataset_key in all_dataset_keys:
            logger.log(
                "\n{:} dataset : {:} {:}".format("-" * 15, dataset_key, "-" * 15)
            )
            dataset_info = checkpoint[dataset_key]
            # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
            logger.log(
                "Flops = {:} MB, Params = {:} MB".format(
                    dataset_info["flop"], dataset_info["param"]
                )
            )
            logger.log("config : {:}".format(dataset_info["config"]))
            logger.log(
                "Training State (finish) = {:}".format(dataset_info["finish-train"])
            )
            last_epoch = dataset_info["total_epoch"] - 1
            train_acc1es, train_acc5es = (
                dataset_info["train_acc1es"],
                dataset_info["train_acc5es"],
            )
            valid_acc1es, valid_acc5es = (
                dataset_info["valid_acc1es"],
                dataset_info["valid_acc5es"],
            )
            print(dataset_info["train_acc1es"])
            print(dataset_info["train_acc5es"])
            print(dataset_info["valid_acc1es"])
            print(dataset_info["valid_acc5es"])
            logger.log(
                "Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%".format(
                    train_acc1es[last_epoch],
                    train_acc5es[last_epoch],
                    100 - train_acc1es[last_epoch],
                    valid_acc1es['ori-test@' + str(last_epoch)],
                    valid_acc5es['ori-test@' + str(last_epoch)],
                    100 - valid_acc1es['ori-test@' + str(last_epoch)],
                )
            )
        # measure elapsed time
        seed_time.update(time.time() - start_time)
        start_time = time.time()
        need_time = "Time Left: {:}".format(
            convert_secs2time(seed_time.avg * (len(seeds) - _is - 1), True)
        )
        logger.log(
            "\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}".format(
                _is, len(seeds), seed, need_time
            )
        )
    logger.close()


# Example architecture string:
# |nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|
train_strs = [best_cifar_10_str, best_cifar_100_str, best_ImageNet16_str]
train_single_model(
    save_dir="./outputs",
    workers=8,
    datasets=["ImageNet16-120"],
    xpaths=["./datasets/imagenet16-120"],
    splits=[0, 0, 0],
    use_less=False,
    seeds=[777],
    model_str=best_ImageNet16_str,
    arch_config={"channel": 16, "num_cells": 8},
)
train_single_model(
    save_dir="./outputs",
    workers=8,
    datasets=["cifar10"],
    xpaths=["./datasets/cifar10"],
    splits=[0, 0, 0],
    use_less=False,
    seeds=[777],
    model_str=best_cifar_10_str,
    arch_config={"channel": 16, "num_cells": 8},
)
train_single_model(
    save_dir="./outputs",
    workers=8,
    datasets=["cifar100"],
    xpaths=["./datasets/cifar100"],
    splits=[0, 0, 0],
    use_less=False,
    seeds=[777],
    model_str=best_cifar_100_str,
    arch_config={"channel": 16, "num_cells": 8},
)