# xautodl/exps/NAS-Bench-201/test-correlation.py
#
# 198 lines
# 7.0 KiB
# Python

#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
########################################################
# python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
########################################################
import sys, argparse
import numpy as np
from copy import deepcopy
from tqdm import tqdm
import torch
from pathlib import Path
from xautodl.log_utils import time_string
from xautodl.models import CellStructure
from nas_201_api import NASBench201API as API
def check_unique_arch(meta_file):
    """Report how many NAS-Bench-201 architectures are structurally unique.

    Loads the benchmark at ``meta_file``, converts every architecture string
    into a :class:`CellStructure`, and counts equivalence classes three times
    using ``to_unique_str`` with ``consider_zero`` set to ``None`` / ``False``
    / ``True``.  Results are printed; nothing is returned.
    """
    api = API(str(meta_file))
    arch_strs = deepcopy(api.meta_archs)
    xarchs = [CellStructure.str2structure(x) for x in arch_strs]

    def get_unique_matrix(archs, consider_zero):
        # Canonical string per architecture: equal strings <=> equivalent archs.
        unique_strs = [arch.to_unique_str(consider_zero) for arch in archs]
        print(
            "{:} create unique-string ({:}/{:}) done".format(
                time_string(), len(set(unique_strs)), len(unique_strs)
            )
        )
        # Group architecture indices by their canonical string.
        groups = {}
        for index, xstr in enumerate(unique_strs):
            groups.setdefault(xstr, []).append(index)
        # Boolean similarity matrix: sm_matrix[i, j] is True iff i and j are
        # in the same equivalence class (diagonal starts True via eye()).
        sm_matrix = torch.eye(len(archs)).bool()
        for members in groups.values():
            for i in members:
                for j in members:
                    sm_matrix[i, j] = True
        # Assign one shared id per equivalence class.
        unique_ids = [-1] * len(archs)
        unique_num = 0
        for i, assigned in enumerate(unique_ids):
            if assigned > -1:
                continue  # already labeled via an earlier neighbour
            for nghb in sm_matrix[i].nonzero().view(-1).tolist():
                assert unique_ids[nghb] == -1, "impossible"
                unique_ids[nghb] = unique_num
            unique_num += 1
        return sm_matrix, unique_ids, unique_num

    print(
        "There are {:} valid-archs".format(sum(arch.check_valid() for arch in xarchs))
    )
    sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, None)
    print(
        "{:} There are {:} unique architectures (considering nothing).".format(
            time_string(), unique_num
        )
    )
    sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, False)
    print(
        "{:} There are {:} unique architectures (not considering zero).".format(
            time_string(), unique_num
        )
    )
    sm_matrix, uniqueIDs, unique_num = get_unique_matrix(xarchs, True)
    print(
        "{:} There are {:} unique architectures (considering zero).".format(
            time_string(), unique_num
        )
    )
def check_cor_for_bandit(
    meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False
):
    """Correlate an early cifar10-valid accuracy with final accuracies.

    For every architecture, reads the validation accuracy after
    ``test_epoch`` training epochs (the bandit's early signal) and the
    fully-trained valid/test accuracies on cifar10, cifar100 and
    ImageNet16-120, then returns the six Pearson correlations in the order
    C-010-V, C-010-T, C-100-V, C-100-T, I16-V, I16-T.
    """
    api = meta_file if isinstance(meta_file, API) else API(str(meta_file))
    cifar10_currs = []
    cifar10_valid = []
    cifar10_test = []
    cifar100_valid = []
    cifar100_test = []
    imagenet_test = []
    imagenet_valid = []
    for idx, arch in enumerate(api):
        # Early-stopped proxy signal at `test_epoch` (0-based epoch index).
        early = api.get_more_info(
            idx, "cifar10-valid", test_epoch - 1, use_less_or_not, is_rand
        )
        cifar10_currs.append(early["valid-accuracy"])
        # --->>>>> fully-trained accuracies per dataset below.
        info = api.get_more_info(idx, "cifar10-valid", None, False, is_rand)
        cifar10_valid.append(info["valid-accuracy"])
        info = api.get_more_info(idx, "cifar10", None, False, is_rand)
        cifar10_test.append(info["test-accuracy"])
        info = api.get_more_info(idx, "cifar100", None, False, is_rand)
        cifar100_test.append(info["test-accuracy"])
        cifar100_valid.append(info["valid-accuracy"])
        info = api.get_more_info(idx, "ImageNet16-120", None, False, is_rand)
        imagenet_test.append(info["test-accuracy"])
        imagenet_valid.append(info["valid-accuracy"])

    def get_cor(A, B):
        # Pearson correlation coefficient between two accuracy lists.
        return float(np.corrcoef(A, B)[0, 1])

    labeled_lists = [
        ("C-010-V", cifar10_valid),
        ("C-010-T", cifar10_test),
        ("C-100-V", cifar100_valid),
        ("C-100-T", cifar100_test),
        ("I16-V", imagenet_valid),
        ("I16-T", imagenet_test),
    ]
    cors = []
    for basestr, xlist in labeled_lists:
        correlation = get_cor(cifar10_currs, xlist)
        if need_print:
            print(
                "With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}".format(
                    test_epoch,
                    "012" if use_less_or_not else "200",
                    basestr,
                    correlation,
                )
            )
        cors.append(correlation)
        # print ('With {:3d}/200-epochs-training, the correlation between cifar10-valid and {:} is : {:}'.format(test_epoch, basestr, get_cor(cifar10_valid_200, xlist)))
        # print('-'*200)
    # print('*'*230)
    return cors
def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand):
    """Run check_cor_for_bandit 100 times and print mean/std per dataset.

    Each trial resamples a random training run (``is_rand``); the per-dataset
    correlation statistics over the 100 trials are printed as mean +/- std.
    """
    corrs = [
        check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand, False)
        for _ in tqdm(range(100))
    ]
    # xstrs = ['CIFAR-010', 'C-100-V', 'C-100-T', 'I16-V', 'I16-T']
    xstrs = ["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"]
    correlations = np.array(corrs)
    print(
        "------>>>>>>>> {:03d}/{:} >>>>>>>> ------".format(
            test_epoch, "012" if use_less_or_not else "200"
        )
    )
    for idx, xstr in enumerate(xstrs):
        column = correlations[:, idx]
        print(
            "{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}".format(
                xstr,
                column.mean(),
                column.std(),
                column.mean(),
                column.std(),
            )
        )
    print("")
if __name__ == "__main__":
    # Command-line driver: load the NAS-Bench-201 file and sweep correlation
    # checks across several early-stopping epoch budgets.
    parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./output/search-cell-nas-bench-201/visuals",
        help="The base-name of folder to save checkpoints and log.",
    )
    parser.add_argument(
        "--api_path",
        type=str,
        default=None,
        help="The path to the NAS-Bench-201 benchmark file.",
    )
    args = parser.parse_args()

    vis_save_dir = Path(args.save_dir)
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    # FIX: `assert` is stripped under `python -O`, so validate explicitly.
    # This also catches a missing --api_path (Path(None) would raise TypeError).
    if args.api_path is None or not Path(args.api_path).exists():
        raise ValueError("invalid path for api : {:}".format(args.api_path))
    meta_file = Path(args.api_path)

    # check_unique_arch(meta_file)
    api = API(str(meta_file))
    # for iepoch in [11, 25, 50, 100, 150, 175, 200]:
    #   check_cor_for_bandit(api, 6, iepoch)
    #   check_cor_for_bandit(api, 12, iepoch)
    # Third argument: True => 12-epoch ("less") training setting, False => 200-epoch.
    check_cor_for_bandit_v2(api, 6, True, True)
    check_cor_for_bandit_v2(api, 12, True, True)
    check_cor_for_bandit_v2(api, 12, False, True)
    check_cor_for_bandit_v2(api, 24, False, True)
    check_cor_for_bandit_v2(api, 100, False, True)
    check_cor_for_bandit_v2(api, 150, False, True)
    check_cor_for_bandit_v2(api, 175, False, True)
    check_cor_for_bandit_v2(api, 200, False, True)
    print("----")