xautodl/exps/NATS-Bench/draw-ranks.py

186 lines
6.9 KiB
Python
Raw Normal View History

2020-12-01 15:25:23 +01:00
###############################################################
# NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021 #
# The code to draw Figure 2 / 3 / 4 / 5 in our paper. #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NATS-Bench/draw-ranks.py #
###############################################################
import os, sys, time, torch, argparse
import scipy
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
2021-03-17 10:25:58 +01:00
from copy import deepcopy
2020-12-01 15:25:23 +01:00
from pathlib import Path
import matplotlib
import seaborn as sns
2021-03-17 10:25:58 +01:00
matplotlib.use("agg")
2020-12-01 15:25:23 +01:00
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from xautodl.config_utils import dict2config, load_config
from xautodl.log_utils import time_string
from xautodl.models import get_cell_based_tiny_net
2020-12-01 15:25:23 +01:00
from nats_bench import create
2021-03-18 09:02:55 +01:00
# Display names for the dataset identifiers; used in the subplot x-axis labels.
name2label = {
    key: label
    for key, label in (
        ("cifar10", "CIFAR-10"),
        ("cifar100", "CIFAR-100"),
        ("ImageNet16-120", "ImageNet-16-120"),
    )
}
def visualize_relative_info(vis_save_dir, search_space, indicator, topk):
    """Plot how stable the architecture ranking is for one search space.

    For each of "cifar10", "cifar100" and "ImageNet16-120", the top-``topk``
    percent of architectures (sorted by the cached ``indicator`` test accuracy)
    are re-scored twice via ``api.get_more_info``: once with
    ``is_random=False`` (legend: "Average Over Multi-Trials") and once with
    ``is_random=True`` (legend: "Randomly Selected Trial"). The two rankings
    are scattered against each other, and their Kendall tau is printed.

    Args:
      vis_save_dir: ``pathlib.Path`` of the output folder (created if missing).
      search_space: search-space name passed to ``create`` (``__main__`` uses
        "tss" and "sss").
      indicator: must be "less" or "more". It selects which cached accuracy
        ranks the architectures. "more" queries ``hp=api.full_train_epochs``;
        anything else queries ``hp="12"``.
      topk: percentage of top-ranked architectures to plot (e.g. 1, 5, 10, 20).

    Side effects:
      Writes ``cache-<search_space>-info.pth`` on the first run, then
      ``<search_space>-rank-<indicator>-top<topk>.pdf`` and ``.png``.
    """
    # A bare ``import scipy`` does not guarantee the ``scipy.stats``
    # submodule is loaded on older SciPy releases, so import it explicitly.
    from scipy import stats

    vis_save_dir = vis_save_dir.resolve()
    print(
        "{:} start to visualize {:} with top-{:} information".format(
            time_string(), search_space, topk
        )
    )
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    cache_file_path = vis_save_dir / "cache-{:}-info.pth".format(search_space)
    datasets = ["cifar10", "cifar100", "ImageNet16-120"]
    if not cache_file_path.exists():
        # First run: query every architecture once and cache both accuracies,
        # so later runs can open the benchmark with fast_mode=True.
        api = create(None, search_space, fast_mode=False, verbose=False)
        all_infos = OrderedDict()
        for index in range(len(api)):
            all_info = OrderedDict()
            for dataset in datasets:
                info_less = api.get_more_info(index, dataset, hp="12", is_random=False)
                info_more = api.get_more_info(
                    index, dataset, hp=api.full_train_epochs, is_random=False
                )
                all_info[dataset] = dict(
                    less=info_less["test-accuracy"], more=info_more["test-accuracy"]
                )
            all_infos[index] = all_info
        torch.save(all_infos, cache_file_path)
        print("{:} save all cache data into {:}".format(time_string(), cache_file_path))
    else:
        api = create(None, search_space, fast_mode=True, verbose=False)
        all_infos = torch.load(cache_file_path)

    dpi, width, height = 250, 5000, 1300
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 16, 16

    fig, axs = plt.subplots(1, 3, figsize=figsize)

    def sub_plot_fn(ax, dataset, indicator):
        """Draw one dataset's ranking comparison on ``ax``; return Kendall's tau."""
        # Keep the top-``topk`` percent of architectures by cached accuracy.
        performances = sorted(
            (
                (all_infos[_index][dataset][indicator], _index)
                for _index in range(len(api))
            ),
            reverse=True,
        )
        performances = performances[: int(len(api) * topk * 0.01)]
        selected_indexes = [x[1] for x in performances]
        print(
            "{:} plot {:10s} with {:}, {:} architectures".format(
                time_string(), dataset, indicator, len(selected_indexes)
            )
        )
        hp = api.full_train_epochs if indicator == "more" else "12"
        standard_scores = []
        random_scores = []
        # Keep the standard/random query interleaving of the original code,
        # in case the is_random=True call consumes shared random state.
        for idx in selected_indexes:
            standard_scores.append(
                api.get_more_info(idx, dataset, hp=hp, is_random=False)[
                    "test-accuracy"
                ]
            )
            random_scores.append(
                api.get_more_info(idx, dataset, hp=hp, is_random=True)[
                    "test-accuracy"
                ]
            )
        indexes = list(range(len(selected_indexes)))
        standard_indexes = sorted(indexes, key=lambda i: standard_scores[i])
        random_indexes = sorted(indexes, key=lambda i: random_scores[i])
        # random_rank[i] is the position of architecture i in the random-trial
        # ordering (the inverse permutation of random_indexes). This is O(n)
        # instead of calling random_indexes.index(...) for every architecture.
        random_rank = [0] * len(indexes)
        for rank, i in enumerate(random_indexes):
            random_rank[i] = rank
        random_labels = [random_rank[i] for i in standard_indexes]

        # tick_params persists across later set_x/yticks calls, unlike editing
        # the Text objects returned by get_x/yticklabels().
        ax.tick_params(axis="x", labelsize=LabelSize - 3)
        ax.tick_params(axis="y", labelsize=LabelSize - 3, labelrotation=25)
        ax.set_xlim(0, len(indexes))
        ax.set_ylim(0, len(indexes))
        if indexes:
            last = max(indexes)
            # Clamp the step to 1 so tiny selections do not make np.arange
            # divide by zero.
            ax.set_yticks(np.arange(0, last, max(last // 3, 1)))
            ax.set_xticks(np.arange(0, last, max(last // 5, 1)))
        ax.scatter(indexes, random_labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
        ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
        # Off-canvas points (axes start at 0) that only provide large legend markers.
        ax.scatter(
            [-1],
            [-1],
            marker="o",
            s=100,
            c="tab:blue",
            label="Average Over Multi-Trials",
        )
        ax.scatter(
            [-1],
            [-1],
            marker="^",
            s=100,
            c="tab:green",
            label="Randomly Selected Trial",
        )

        coef, _ = stats.kendalltau(standard_scores, random_scores)
        ax.set_xlabel(
            "architecture ranking in {:}".format(name2label[dataset]),
            fontsize=LabelSize,
        )
        if dataset == "cifar10":
            ax.set_ylabel("architecture ranking", fontsize=LabelSize)
        ax.legend(loc=4, fontsize=LegendFontsize)
        return coef

    for dataset, ax in zip(datasets, axs):
        rank_coef = sub_plot_fn(ax, dataset, indicator)
        print(
            "sub-plot {:} on {:} done, the ranking coefficient is {:.4f}.".format(
                dataset, search_space, rank_coef
            )
        )

    for file_format in ("pdf", "png"):
        save_path = (
            vis_save_dir
            / "{:}-rank-{:}-top{:}.{:}".format(search_space, indicator, topk, file_format)
        ).resolve()
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format=file_format)
    print("Save into {:}".format(save_path))
    # This function is called repeatedly from __main__, and pyplot keeps every
    # open figure alive, so release this one explicitly.
    plt.close(fig)


if __name__ == "__main__":
    # Command-line entry point: draw the ranking-stability figures for both
    # search spaces at several top-k percentages.
    parser = argparse.ArgumentParser(
        description="NATS-Bench", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="output/vis-nas-bench/rank-stability",
        help="Folder to save the cached accuracies and the ranking figures.",
    )
    args = parser.parse_args()
    to_save_dir = Path(args.save_dir)

    # "tss" is ranked with the "more" indicator and "sss" with the "less"
    # indicator, for each top-k percentage.
    for topk in [1, 5, 10, 20]:
        visualize_relative_info(to_save_dir, "tss", "more", topk)
        visualize_relative_info(to_save_dir, "sss", "less", topk)
    print("{:} : complete running this file : {:}".format(time_string(), __file__))