###############################################################
# NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021  #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06           #
###############################################################
# Usage: python exps/NATS-Bench/draw-correlations.py          #
###############################################################
import argparse
from pathlib import Path

import scipy.stats
import matplotlib

matplotlib.use("agg")
import matplotlib.pyplot as plt

from nats_bench import create


def get_valid_test_acc(api, arch, dataset):
    """Query the validation and test accuracy of `arch` on `dataset`."""
    is_size_space = api.search_space_name == "size"
    if dataset == "cifar10":
        # For CIFAR-10, the test accuracy comes from the "cifar10" entry
        # (trained on train+val), while the validation accuracy comes from
        # the "cifar10-valid" split.
        xinfo = api.get_more_info(
            arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
        )
        test_acc = xinfo["test-accuracy"]
        xinfo = api.get_more_info(
            arch,
            dataset="cifar10-valid",
            hp=90 if is_size_space else 200,
            is_random=False,
        )
        valid_acc = xinfo["valid-accuracy"]
    else:
        xinfo = api.get_more_info(
            arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
        )
        valid_acc = xinfo["valid-accuracy"]
        test_acc = xinfo["test-accuracy"]
    return (
        valid_acc,
        test_acc,
        "validation = {:.2f}, test = {:.2f}\n".format(valid_acc, test_acc),
    )


def compute_kendalltau(vectori, vectorj):
    """Kendall's tau rank correlation between two score vectors."""
    coef, p = scipy.stats.kendalltau(vectori, vectorj)
    return coef


def compute_spearmanr(vectori, vectorj):
    """Spearman's rho rank correlation between two score vectors."""
    coef, p = scipy.stats.spearmanr(vectori, vectorj)
    return coef


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="output/vis-nas-bench/nas-algos",
        help="Folder to save the figure and log.",
    )
    parser.add_argument(
        "--search_space",
        type=str,
        default="tss",
        choices=["tss", "sss"],
        help="Choose the search space.",
    )
    args = parser.parse_args()
    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    api = create(None, args.search_space, fast_mode=True, verbose=False)
    # Sub-sample architectures to keep the number of benchmark queries small.
    indexes = list(range(1, 10000, 300))
    scores_1, scores_2 = [], []
    for index in indexes:
        valid_acc, test_acc, _ = get_valid_test_acc(api, index, "cifar10")
        scores_1.append(valid_acc)
        scores_2.append(test_acc)

    correlation = compute_kendalltau(scores_1, scores_2)
    print(
        "The Kendall tau correlation of {:} samples : {:}".format(
            len(indexes), correlation
        )
    )
    correlation = compute_spearmanr(scores_1, scores_2)
    print(
        "The Spearman correlation of {:} samples : {:}".format(
            len(indexes), correlation
        )
    )

    # Scatter plot of validation accuracy vs. test accuracy.
    dpi, width, height = 250, 1000, 1000
    figsize = width / float(dpi), height / float(dpi)

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.scatter(scores_1, scores_2, marker="^", s=0.5, c="tab:green", alpha=0.8)

    save_path = save_dir / "test-temp-rank.png"
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
    plt.close("all")
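# ---------------------------------------------------------------------------
# Optional sanity check for the two rank-correlation helpers above, using toy
# rankings with hand-computed expected values (illustrative only; not part of
# the benchmark run and not executed by this script):
#
#   compute_kendalltau([1, 2, 3, 4], [1, 2, 4, 3])  # ~0.6667 (1 discordant pair out of 6)
#   compute_spearmanr([1, 2, 3, 4], [1, 2, 4, 3])   # ~0.8    (rho = 1 - 6*2 / (4*(16-1)))
# ---------------------------------------------------------------------------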