xautodl/exps/NAS-Bench-201/test-correlation.py

160 lines
6.7 KiB
Python
Raw Normal View History

2020-02-23 00:30:37 +01:00
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
2019-12-31 12:02:11 +01:00
########################################################
# python exps/NAS-Bench-201/test-correlation.py --api_path $HOME/.torch/NAS-Bench-201-v1_0-e61699.pth
2019-12-31 12:02:11 +01:00
########################################################
2020-03-13 22:00:54 +01:00
import sys, argparse
2019-12-31 12:02:11 +01:00
import numpy as np
from copy import deepcopy
2020-01-01 12:18:42 +01:00
from tqdm import tqdm
2019-12-31 12:02:11 +01:00
import torch
from pathlib import Path
2021-03-17 10:25:58 +01:00
# Put the repository's ``lib`` directory (two levels above this script's folder)
# at the front of sys.path so the local packages below are importable.
lib_dir = Path(__file__).parent.joinpath("..", "..", "lib").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
from log_utils import time_string
from models import CellStructure
from nas_201_api import NASBench201API as API
2019-12-31 12:02:11 +01:00
def check_unique_arch(meta_file):
    """Print how many NAS-Bench-201 cells are distinct under three notions of uniqueness.

    Every architecture string in the benchmark is parsed with
    ``CellStructure.str2structure`` and the resulting cells are grouped by
    ``CellStructure.to_unique_str(consider_zero)`` for ``consider_zero`` in
    ``None`` / ``False`` / ``True``. These modes are reported as "considering nothing",
    "not considering zero" and "considering zero". The exact canonicalisation each
    mode applies lives in ``to_unique_str`` and is not visible here.

    Args:
      meta_file: an already-loaded ``API`` instance, or a path to the benchmark
        file (anything whose ``str()`` is accepted by ``API``).

    Returns:
      None; all results are printed.
    """
    # Accept a preloaded API (as check_cor_for_bandit does) so the benchmark
    # file does not have to be loaded a second time.
    api = meta_file if isinstance(meta_file, API) else API(str(meta_file))
    arch_strs = deepcopy(api.meta_archs)
    xarchs = [CellStructure.str2structure(x) for x in arch_strs]

    def get_unique_ids(archs, consider_zero):
        """Label each arch with the id of its unique-string group.

        Ids are assigned in order of each group's first appearance in ``archs``.
        Returns ``(unique_ids, unique_num)``: ``unique_ids[i]`` is the group id of
        ``archs[i]`` and ``unique_num`` is the number of distinct groups.
        """
        unique_strs = [arch.to_unique_str(consider_zero) for arch in archs]
        print("{:} create unique-string ({:}/{:}) done".format(time_string(), len(set(unique_strs)), len(unique_strs)))
        # One O(n) pass over a dict keyed by unique string gives the group ids
        # directly; no n x n similarity matrix is needed. setdefault evaluates
        # len(str2id) before inserting, so a new string gets the next free id.
        str2id = {}
        unique_ids = [str2id.setdefault(xstr, len(str2id)) for xstr in unique_strs]
        return unique_ids, len(str2id)

    print("There are {:} valid-archs".format(sum(arch.check_valid() for arch in xarchs)))
    for consider_zero, description in (
        (None, "considering nothing"),
        (False, "not considering zero"),
        (True, "considering zero"),
    ):
        _, unique_num = get_unique_ids(xarchs, consider_zero)
        print("{:} There are {:} unique architectures ({:}).".format(time_string(), unique_num, description))
2019-12-31 12:02:11 +01:00
2020-01-01 12:18:42 +01:00
def check_cor_for_bandit(meta_file, test_epoch, use_less_or_not, is_rand=True, need_print=False):
    """Correlate an early cifar10-valid accuracy with final accuracies on every dataset.

    For each architecture in the benchmark, the proxy score is the cifar10-valid
    "valid-accuracy" queried at epoch index ``test_epoch - 1``. It is compared against
    six accuracies queried with epoch ``None`` and a False schedule flag:
    cifar10-valid valid, cifar10 test, cifar100 valid/test and ImageNet16-120
    valid/test.

    Args:
      meta_file: a loaded ``API`` instance or a path to the benchmark file.
      test_epoch: the proxy is read at epoch index ``test_epoch - 1``.
      use_less_or_not: schedule flag forwarded to the proxy query; the printout
        labels it "012" when true and "200" otherwise.
      is_rand: forwarded to every ``get_more_info`` call.
      need_print: print each correlation as it is computed.

    Returns:
      A list of six floats (``numpy.corrcoef`` coefficients) in the order
      C-010-V, C-010-T, C-100-V, C-100-T, I16-V, I16-T.
    """
    api = meta_file if isinstance(meta_file, API) else API(str(meta_file))
    proxy_accs = []
    # One list of final accuracies per target; insertion order fixes the output order.
    targets = {label: [] for label in ("C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T")}
    for index, _ in enumerate(api):
        # The query order is kept as-is, since is_rand may make each query consume randomness.
        early = api.get_more_info(index, "cifar10-valid", test_epoch - 1, use_less_or_not, is_rand)
        proxy_accs.append(early["valid-accuracy"])
        c10_valid = api.get_more_info(index, "cifar10-valid", None, False, is_rand)
        targets["C-010-V"].append(c10_valid["valid-accuracy"])
        c10 = api.get_more_info(index, "cifar10", None, False, is_rand)
        targets["C-010-T"].append(c10["test-accuracy"])
        c100 = api.get_more_info(index, "cifar100", None, False, is_rand)
        targets["C-100-T"].append(c100["test-accuracy"])
        targets["C-100-V"].append(c100["valid-accuracy"])
        i16 = api.get_more_info(index, "ImageNet16-120", None, False, is_rand)
        targets["I16-T"].append(i16["test-accuracy"])
        targets["I16-V"].append(i16["valid-accuracy"])

    budget = "012" if use_less_or_not else "200"
    cors = []
    for label, accs in targets.items():
        correlation = float(np.corrcoef(proxy_accs, accs)[0, 1])
        if need_print:
            print(
                "With {:3d}/{:}-epochs-training, the correlation between cifar10-valid and {:} is : {:}".format(
                    test_epoch, budget, label, correlation
                )
            )
        cors.append(correlation)
    return cors
2019-12-31 12:02:11 +01:00
2020-01-01 12:18:42 +01:00
def check_cor_for_bandit_v2(meta_file, test_epoch, use_less_or_not, is_rand, num_repeats=100):
    """Repeat ``check_cor_for_bandit`` and print the mean/std of each correlation.

    Args:
      meta_file: a loaded ``API`` instance or a path to the benchmark file. A path
        is loaded once here, not once per repetition.
      test_epoch, use_less_or_not, is_rand: forwarded unchanged to
        ``check_cor_for_bandit``. ``is_rand`` is the API's randomness flag; when it is
        off, repetitions are presumably identical and the std should be ~0.
      num_repeats: number of repetitions (default 100, the previously fixed value).

    Returns:
      numpy array of shape (num_repeats, 6). Row r holds repetition r's
      correlations; columns follow the printed labels.

    Raises:
      ValueError: if ``num_repeats`` is smaller than 1.
    """
    if num_repeats < 1:
        raise ValueError("num_repeats must be >= 1, got {:}".format(num_repeats))
    # Load the (large) benchmark once instead of on every repetition.
    api = meta_file if isinstance(meta_file, API) else API(str(meta_file))
    corrs = [
        check_cor_for_bandit(api, test_epoch, use_less_or_not, is_rand, False)
        for _ in tqdm(range(num_repeats))
    ]
    xstrs = ["C-010-V", "C-010-T", "C-100-V", "C-100-T", "I16-V", "I16-T"]
    correlations = np.array(corrs)
    print("------>>>>>>>> {:03d}/{:} >>>>>>>> ------".format(test_epoch, "012" if use_less_or_not else "200"))
    for idx, xstr in enumerate(xstrs):
        mean, std = correlations[:, idx].mean(), correlations[:, idx].std()
        # The second pair repeats the numbers in LaTeX-style "mean\pm std" form.
        print("{:8s} ::: mean={:.4f}, std={:.4f} :: {:.4f}\\pm{:.4f}".format(xstr, mean, std, mean, std))
    print("")
    return correlations
2020-01-01 12:18:42 +01:00
2021-03-17 10:25:58 +01:00
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./output/search-cell-nas-bench-201/visuals",
        help="The base-name of folder to save checkpoints and log.",
    )
    parser.add_argument("--api_path", type=str, default=None, help="The path to the NAS-Bench-201 benchmark file.")
    args = parser.parse_args()

    vis_save_dir = Path(args.save_dir)
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    # --api_path defaults to None, and Path(None) dies with an opaque TypeError;
    # an `assert` would also be stripped under `python -O`. Report both problems
    # through argparse instead (prints usage and exits with status 2).
    if args.api_path is None:
        parser.error("--api_path is required: the path to the NAS-Bench-201 benchmark file")
    meta_file = Path(args.api_path)
    if not meta_file.exists():
        parser.error("invalid path for api : {:}".format(meta_file))

    # Load the benchmark once and share it across every run below.
    # (check_unique_arch(meta_file) can be invoked here to count distinct cells.)
    api = API(str(meta_file))
    # (test_epoch, use_less_or_not): True is labelled "012" in the printout, False "200".
    settings = [
        (6, True),
        (12, True),
        (12, False),
        (24, False),
        (100, False),
        (150, False),
        (175, False),
        (200, False),
    ]
    for test_epoch, use_less_or_not in settings:
        check_cor_for_bandit_v2(api, test_epoch, use_less_or_not, True)
    print("----")