xautodl/exps/experimental/test-ww-bench.py
2021-05-24 13:06:10 +08:00

199 lines
7.5 KiB
Python

#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
###########################################################################################################################################################
# Before running these commands, the required files must be put in the proper locations.
#
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-sss-v1_0-50262 --dataset cifar10
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-sss-v1_0-50262 --dataset cifar100
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-sss-v1_0-50262 --dataset ImageNet16-120
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space tss --base_path $HOME/.torch/NATS-tss-v1_0-3ffb9 --dataset cifar10
###########################################################################################################################################################
import os, gc, sys, math, argparse, psutil
import numpy as np
import torch
from pathlib import Path
from collections import OrderedDict
import matplotlib
import seaborn as sns
matplotlib.use("agg")
import matplotlib.pyplot as plt
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from log_utils import time_string
from nats_bench import create
from models import get_cell_based_tiny_net
from utils import weight_watcher
"""
def get_cor(A, B):
return float(np.corrcoef(A, B)[0,1])
def tostr(accdict, norms):
xstr = []
for key, accs in accdict.items():
cor = get_cor(accs, norms)
xstr.append('{:}: {:.3f}'.format(key, cor))
return ' '.join(xstr)
"""
def evaluate(api, weight_dir, data: str, total: int = 5000):
    """Collect weight-watcher log-norms and test accuracies of random architectures.

    For each of ``total`` draws, an architecture index is obtained from
    ``api.random()``, ``api.reload(weight_dir, arch_index)`` is called, the
    network is built from ``api.get_net_config`` and loaded with the stored
    parameters, and ``weight_watcher.analyze`` is run on it. A draw is skipped
    when the returned summary has no ``"lognorm"`` entry or when that value
    is NaN.

    Args:
        api: The benchmark API object (created via ``nats_bench.create`` by the caller).
        weight_dir: Directory handed to ``api.reload``.
        data: Dataset name used for the net config, parameters and metrics.
        total: Number of random draws to attempt (defaults to 5000).

    Returns:
        ``(norms, accuracies)``: two lists of equal length. ``norms[i]`` is the
        negated ``"lognorm"`` and ``accuracies[i]`` the ``"ori-test"`` accuracy
        of the i-th draw that was kept.
    """
    print("\nEvaluate dataset={:}".format(data))
    process = psutil.Process(os.getpid())
    # The hyper-parameter tag and the seed depend only on the search space,
    # so they are resolved once instead of on every iteration.
    is_topology = api.search_space_name == "topology"
    hp = "200" if is_topology else "90"
    seed = 888 if is_topology else 777
    norms, accuracies = [], []
    ok = 0
    for idx in range(total):
        arch_index = api.random()
        api.reload(weight_dir, arch_index)
        # compute the weight watcher results
        config = api.get_net_config(arch_index, data)
        net = get_cell_based_tiny_net(config)
        meta_info = api.query_meta_info_by_index(arch_index, hp=hp)
        params = meta_info.get_net_param(data, seed)
        with torch.no_grad():
            net.load_state_dict(params)
            _, summary = weight_watcher.analyze(net, alphas=False)
        # clear_params runs for every draw, whether or not the draw is kept.
        api.clear_params(arch_index, None)
        del net, params
        # A missing "lognorm" is treated like a NaN one: the draw is skipped.
        cur_norm = -summary["lognorm"] if "lognorm" in summary else math.nan
        if math.isnan(cur_norm):
            del meta_info
            continue
        ok += 1
        norms.append(cur_norm)
        # query the accuracy
        info = meta_info.get_metrics(
            data,
            "ori-test",
            iepoch=None,
            is_random=seed,
        )
        accuracies.append(info["accuracy"])
        del meta_info
        # print the information (only reached for kept draws)
        if idx % 20 == 0:
            gc.collect()
            print(
                "{:} {:04d}_{:04d}/{:04d} ({:.2f} MB memory)".format(
                    time_string(), ok, idx, total, process.memory_info().rss / 1e6
                )
            )
    return norms, accuracies
def main(search_space, meta_file: str, weight_dir, save_dir, xdata):
    """Plot the weight-watcher ranking against the test-accuracy ranking.

    Prints the per-dataset / per-hyper-parameter trial statistics reported by
    the API, calls ``evaluate`` on ``xdata``, and saves a scatter plot as both
    PDF and PNG under ``save_dir``. The x axis is the position of each kept
    architecture when sorted by test accuracy; the red points give its
    position when sorted by the weight-watcher norm, and the blue points are
    the ``y = x`` reference.

    Args:
        search_space: Search-space name passed to ``create`` and used in the file names.
        meta_file: Path of the benchmark file passed to ``create``.
        weight_dir: Directory with the stored weights, forwarded to ``evaluate``.
        save_dir: ``pathlib.Path`` of the output folder; created if missing.
        xdata: Dataset name forwarded to ``evaluate`` and used in the file names.

    Raises:
        ValueError: If ``evaluate`` returns no valid (norm, accuracy) pair.
    """
    save_dir.mkdir(parents=True, exist_ok=True)
    api = create(meta_file, search_space, verbose=False)
    datasets = ["cifar10-valid", "cifar10", "cifar100", "ImageNet16-120"]
    print(time_string() + " " + "=" * 50)
    for data in datasets:
        for hp in api.avaliable_hps:
            nums = api.statistics(data, hp=hp)
            total = sum(k * v for k, v in nums.items())
            print(
                "Using {:3s} epochs, trained on {:20s} : {:} trials in total ({:}).".format(
                    hp, data, total, nums
                )
            )
    print(time_string() + " " + "=" * 50)

    norms, accuracies = evaluate(api, weight_dir, xdata)
    if not norms:
        # Without this guard the failure would surface later as an opaque
        # "min() arg is an empty sequence" from the axis-limit computation.
        raise ValueError(
            "no architecture produced a valid weight-watcher norm on {:}".format(xdata)
        )

    indexes = list(range(len(norms)))
    norm_order = sorted(indexes, key=lambda i: norms[i])
    accy_order = sorted(indexes, key=lambda i: accuracies[i])
    # Position of every architecture in the norm ordering; a dictionary lookup
    # replaces the quadratic list.index scan.
    norm_rank = {arch: rank for rank, arch in enumerate(norm_order)}
    # x follows the accuracy ordering so that the axis labels and the legend
    # below describe the plotted data: labels[x] is the weight-watcher rank of
    # the architecture whose accuracy rank is x.
    labels = [norm_rank[arch] for arch in accy_order]

    dpi, width, height = 200, 1400, 800
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 18, 12

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    first, last = indexes[0], indexes[-1]
    plt.xlim(first, last)
    plt.ylim(first, last)
    # plt.ylabel('y').set_rotation(30)
    # The tick step is clamped to 1: with only a few points the integer
    # division yields 0, which np.arange rejects.
    plt.yticks(
        np.arange(first, last, max(1, last // 3)),
        fontsize=LegendFontsize,
        rotation="vertical",
    )
    plt.xticks(
        np.arange(first, last, max(1, last // 5)),
        fontsize=LegendFontsize,
    )
    ax.scatter(indexes, labels, marker="*", s=0.5, c="tab:red", alpha=0.8)
    ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
    # Off-canvas points drawn only to give the legend readable marker sizes.
    ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="Test accuracy")
    ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="Weight watcher")
    plt.grid(zorder=0)
    ax.set_axisbelow(True)
    plt.legend(loc=0, fontsize=LegendFontsize)
    ax.set_xlabel(
        "architecture ranking sorted by the test accuracy ", fontsize=LabelSize
    )
    ax.set_ylabel("architecture ranking computed by weight watcher", fontsize=LabelSize)

    for ext in ("pdf", "png"):
        save_path = (
            save_dir / "{:}-{:}-test-ww.{:}".format(search_space, xdata, ext)
        ).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format=ext)
        print("{:} save into {:}".format(time_string(), save_path))
    plt.close(fig)
    print("{:} finish this test.".format(time_string()))
if __name__ == "__main__":
    # Command-line entry point: parse the options, check that the benchmark
    # file and the weight directory exist, then run the analysis.
    parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./output/vis-nas-bench/",
        help="The base-name of folder to save checkpoints and log.",
    )
    # The next three options have no usable default (a missing --base_path
    # used to crash on None + ".pth"), so argparse now enforces them.
    parser.add_argument(
        "--search_space",
        type=str,
        required=True,
        choices=["tss", "sss"],
        help="The search space.",
    )
    parser.add_argument(
        "--base_path",
        type=str,
        required=True,
        help="The common path prefix: BASE_PATH.pth is the benchmark file "
        "and BASE_PATH-full is the weight dir.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="The dataset to evaluate (e.g., cifar10, cifar100, ImageNet16-120).",
    )
    args = parser.parse_args()

    # Both locations are derived by appending a suffix to the same prefix.
    meta_file = Path(args.base_path + ".pth")
    weight_dir = Path(args.base_path + "-full")
    # parser.error is used instead of assert so that the checks still run
    # when Python is started with -O.
    if not meta_file.exists():
        parser.error("invalid path for api : {:}".format(meta_file))
    if not weight_dir.is_dir():
        parser.error("invalid path for weight dir : {:}".format(weight_dir))

    # main() creates the output folder itself.
    main(
        args.search_space,
        str(meta_file),
        weight_dir,
        Path(args.save_dir),
        args.dataset,
    )