# xautodl/exps/NATS-Bench/draw-fig2_5.py
#
# 652 lines
# 24 KiB
# Python

###############################################################
# NATS-Bench (arxiv.org/pdf/2009.00437.pdf), IEEE TPAMI 2021 #
# The code to draw Figure 2 / 3 / 4 / 5 in our paper. #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/NATS-Bench/draw-fig2_5.py #
###############################################################
import os, sys, time, torch, argparse
import scipy
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from xautodl.config_utils import dict2config, load_config
from xautodl.log_utils import time_string
from xautodl.models import get_cell_based_tiny_net
from nats_bench import create
def visualize_relative_info(api, vis_save_dir, indicator):
    """Plot how the CIFAR-10 test-accuracy ranking carries over to other datasets.

    Architectures are ordered by their cached CIFAR-10 test accuracy (x axis).
    For each of them the y axis shows its rank under the cached CIFAR-100 and
    ImageNet16-120 test accuracies; the CIFAR-10 ranking itself is drawn as the
    diagonal. The figure is saved as ``<indicator>-relative-rank.pdf`` and
    ``<indicator>-relative-rank.png`` inside ``vis_save_dir``.

    The three ``<dataset>-cache-<indicator>-info.pth`` files are loaded
    unconditionally, so they must already exist in ``vis_save_dir``.

    Args:
        api: Unused; kept so existing callers (which pass ``None``) keep working.
        vis_save_dir: ``pathlib.Path`` of the directory that holds the cache
            files and receives the figures.
        indicator: Tag used in the cache and output file names (the script
            entry point passes ``"tss"`` or ``"sss"``).
    """
    vis_save_dir = vis_save_dir.resolve()
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    def load_cache(dataset):
        cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
            dataset, indicator
        )
        return torch.load(cache_path)

    cifar010_info = load_cache("cifar10")
    cifar100_info = load_cache("cifar100")
    imagenet_info = load_cache("ImageNet16-120")
    indexes = list(range(len(cifar010_info["params"])))
    print("{:} start to visualize relative ranking".format(time_string()))

    def order_by_test_acc(info):
        # Architecture indexes sorted from lowest to highest test accuracy.
        return sorted(indexes, key=lambda i: info["test_accs"][i])

    def rank_lookup(ordered):
        # Map architecture index -> rank once, instead of calling list.index()
        # for every architecture (which is quadratic in the number of entries).
        return {arch: rank for rank, arch in enumerate(ordered)}

    cifar010_ord_indexes = order_by_test_acc(cifar010_info)
    cifar100_rank = rank_lookup(order_by_test_acc(cifar100_info))
    imagenet_rank = rank_lookup(order_by_test_acc(imagenet_info))
    cifar100_labels = [cifar100_rank[idx] for idx in cifar010_ord_indexes]
    imagenet_labels = [imagenet_rank[idx] for idx in cifar010_ord_indexes]
    print("{:} prepare data done.".format(time_string()))

    dpi, width, height = 200, 1400, 800
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 18, 12
    lowest, highest = min(indexes), max(indexes)

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    plt.xlim(lowest, highest)
    plt.ylim(lowest, highest)
    # max(1, ...) keeps the tick step positive when the cache is very small.
    plt.yticks(
        np.arange(lowest, highest, max(1, highest // 3)),
        fontsize=LegendFontsize,
        rotation="vertical",
    )
    plt.xticks(
        np.arange(lowest, highest, max(1, highest // 5)),
        fontsize=LegendFontsize,
    )
    ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
    ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8)
    ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
    # Points placed at (-1, -1) lie outside the axis limits; they exist only to
    # give the legend large, readable markers.
    ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="CIFAR-10")
    ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="CIFAR-100")
    ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="ImageNet-16-120")
    plt.grid(zorder=0)
    ax.set_axisbelow(True)
    plt.legend(loc=0, fontsize=LegendFontsize)
    ax.set_xlabel("architecture ranking in CIFAR-10", fontsize=LabelSize)
    ax.set_ylabel("architecture ranking", fontsize=LabelSize)

    save_path = (vis_save_dir / "{:}-relative-rank.pdf".format(indicator)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
    save_path = (vis_save_dir / "{:}-relative-rank.png".format(indicator)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
    print("{:} save into {:}".format(time_string(), save_path))
    # Release the figure, as the other visualize_* functions in this file do.
    plt.close("all")
def visualize_sss_info(api, dataset, vis_save_dir):
    """Draw cost-versus-accuracy scatter plots for one dataset (``sss`` cache).

    Per-architecture parameters, FLOPs and train/valid/test accuracies are
    read from ``<dataset>-cache-sss-info.pth`` in ``vis_save_dir`` when that
    file exists; otherwise they are queried from ``api`` (with ``hp="90"``)
    and the cache file is written. Four panels are drawn (params/FLOPs against
    train/test accuracy) with the "pyramid" candidates and the
    ``64:64:64:64:64`` candidate highlighted, and the figure is saved as
    ``sss-<dataset lower-cased>.png``.

    Args:
        api: Benchmark handle supporting ``len()``, ``get_cost_info``,
            ``get_more_info`` and ``query_index_by_arch``.
        dataset: Dataset name passed to the API and used in the file names.
        vis_save_dir: ``pathlib.Path`` of the cache/output directory.
    """
    vis_save_dir = vis_save_dir.resolve()
    print("{:} start to visualize {:} information".format(time_string(), dataset))
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    cache_file_path = vis_save_dir / "{:}-cache-sss-info.pth".format(dataset)
    if not cache_file_path.exists():
        print("Do not find cache file : {:}".format(cache_file_path))
        params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
        for index in range(len(api)):
            cost_info = api.get_cost_info(index, dataset, hp="90")
            params.append(cost_info["params"])
            flops.append(cost_info["flops"])
            # accuracy
            info = api.get_more_info(index, dataset, hp="90", is_random=False)
            train_accs.append(info["train-accuracy"])
            test_accs.append(info["test-accuracy"])
            if dataset == "cifar10":
                # For CIFAR-10 the validation accuracy is taken from the
                # separate "cifar10-valid" entry.
                info = api.get_more_info(
                    index, "cifar10-valid", hp="90", is_random=False
                )
            valid_accs.append(info["valid-accuracy"])
        info = {
            "params": params,
            "flops": flops,
            "train_accs": train_accs,
            "valid_accs": valid_accs,
            "test_accs": test_accs,
        }
        torch.save(info, cache_file_path)
    else:
        print("Find cache file : {:}".format(cache_file_path))
        info = torch.load(cache_file_path)
        params, flops, train_accs, valid_accs, test_accs = (
            info["params"],
            info["flops"],
            info["train_accs"],
            info["valid_accs"],
            info["test_accs"],
        )
    print("{:} collect data done.".format(time_string()))

    pyramid = ["8:16:24:32:40", "8:16:32:48:64", "32:40:48:56:64"]
    pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid]
    largest_indexes = [api.query_index_by_arch("64:64:64:64:64")]

    dpi, width, height = 250, 8500, 1300
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 24, 24
    xscale, xalpha = 120, 0.8

    fig, axs = plt.subplots(1, 4, figsize=figsize)
    for ax in axs:
        # tick_params replaces per-tick ``tick.label.set_fontsize``: the
        # ``Tick.label`` attribute no longer exists in recent Matplotlib.
        ax.tick_params(axis="both", labelsize=LabelSize)
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f"))

    def draw_panel(ax, costs, accs, xlabel, ylabel=None):
        # One panel: every candidate as a small dot plus the highlighted ones.
        ax.scatter(costs, accs, marker="o", s=0.5, c="tab:blue")
        ax.scatter(
            [costs[x] for x in pyramid_indexes],
            [accs[x] for x in pyramid_indexes],
            marker="*",
            s=xscale,
            c="tab:orange",
            label="Pyramid Structure",
            alpha=xalpha,
        )
        ax.scatter(
            [costs[x] for x in largest_indexes],
            [accs[x] for x in largest_indexes],
            marker="x",
            s=xscale,
            c="tab:green",
            label="Largest Candidate",
            alpha=xalpha,
        )
        ax.set_xlabel(xlabel, fontsize=LabelSize)
        if ylabel is not None:
            ax.set_ylabel(ylabel, fontsize=LabelSize)
        ax.legend(loc=4, fontsize=LegendFontsize)

    ax1, ax2, ax3, ax4 = axs
    # Only the first panel of each accuracy pair carries a y label.
    draw_panel(ax1, params, train_accs, "#parameters (MB)", "train accuracy (%)")
    draw_panel(ax2, flops, train_accs, "#FLOPs (M)")
    draw_panel(ax3, params, test_accs, "#parameters (MB)", "test accuracy (%)")
    draw_panel(ax4, flops, test_accs, "#FLOPs (M)")

    save_path = vis_save_dir / "sss-{:}.png".format(dataset.lower())
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
    print("{:} save into {:}".format(time_string(), save_path))
    plt.close("all")
def visualize_tss_info(api, dataset, vis_save_dir):
    """Draw cost-versus-accuracy scatter plots for one dataset (``tss`` cache).

    Per-architecture parameters, FLOPs and train/valid/test accuracies are
    read from ``<dataset>-cache-tss-info.pth`` in ``vis_save_dir`` when that
    file exists; otherwise they are queried from ``api`` and the cache file is
    written. Four panels are drawn (params/FLOPs against train/test accuracy)
    with the ResNet-like cell and the all-``nor_conv_3x3`` cell highlighted,
    and the figure is saved as ``tss-<dataset lower-cased>.png``.

    Args:
        api: Benchmark handle supporting ``len()``, ``get_cost_info``,
            ``get_more_info`` and ``query_index_by_arch``.
        dataset: Dataset name passed to the API and used in the file names.
        vis_save_dir: ``pathlib.Path`` of the cache/output directory.
    """
    vis_save_dir = vis_save_dir.resolve()
    print("{:} start to visualize {:} information".format(time_string(), dataset))
    vis_save_dir.mkdir(parents=True, exist_ok=True)
    cache_file_path = vis_save_dir / "{:}-cache-tss-info.pth".format(dataset)
    if not cache_file_path.exists():
        print("Do not find cache file : {:}".format(cache_file_path))
        params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
        for index in range(len(api)):
            # NOTE(review): cost is queried with hp="12" while accuracies use
            # hp="200" -- presumably the cost does not depend on hp; confirm.
            cost_info = api.get_cost_info(index, dataset, hp="12")
            params.append(cost_info["params"])
            flops.append(cost_info["flops"])
            # accuracy
            info = api.get_more_info(index, dataset, hp="200", is_random=False)
            train_accs.append(info["train-accuracy"])
            test_accs.append(info["test-accuracy"])
            if dataset == "cifar10":
                # For CIFAR-10 the validation accuracy is taken from the
                # separate "cifar10-valid" entry.
                info = api.get_more_info(
                    index, "cifar10-valid", hp="200", is_random=False
                )
            valid_accs.append(info["valid-accuracy"])
        # Emit a single blank line once the collection loop has finished.
        print("")
        info = {
            "params": params,
            "flops": flops,
            "train_accs": train_accs,
            "valid_accs": valid_accs,
            "test_accs": test_accs,
        }
        torch.save(info, cache_file_path)
    else:
        print("Find cache file : {:}".format(cache_file_path))
        info = torch.load(cache_file_path)
        params, flops, train_accs, valid_accs, test_accs = (
            info["params"],
            info["flops"],
            info["train_accs"],
            info["valid_accs"],
            info["test_accs"],
        )
    print("{:} collect data done.".format(time_string()))

    resnet = [
        "|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"
    ]
    resnet_indexes = [api.query_index_by_arch(x) for x in resnet]
    largest_indexes = [
        api.query_index_by_arch(
            "|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|"
        )
    ]

    dpi, width, height = 250, 8500, 1300
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 24, 24
    xscale, xalpha = 120, 0.8

    fig, axs = plt.subplots(1, 4, figsize=figsize)
    for ax in axs:
        # tick_params replaces per-tick ``tick.label.set_fontsize``: the
        # ``Tick.label`` attribute no longer exists in recent Matplotlib.
        ax.tick_params(axis="both", labelsize=LabelSize)
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f"))

    def draw_panel(ax, costs, accs, xlabel, ylabel=None):
        # One panel: every candidate as a small dot plus the highlighted ones.
        ax.scatter(costs, accs, marker="o", s=0.5, c="tab:blue")
        ax.scatter(
            [costs[x] for x in resnet_indexes],
            [accs[x] for x in resnet_indexes],
            marker="*",
            s=xscale,
            c="tab:orange",
            label="ResNet",
            alpha=xalpha,
        )
        ax.scatter(
            [costs[x] for x in largest_indexes],
            [accs[x] for x in largest_indexes],
            marker="x",
            s=xscale,
            c="tab:green",
            label="Largest Candidate",
            alpha=xalpha,
        )
        ax.set_xlabel(xlabel, fontsize=LabelSize)
        if ylabel is not None:
            ax.set_ylabel(ylabel, fontsize=LabelSize)
        ax.legend(loc=4, fontsize=LegendFontsize)

    ax1, ax2, ax3, ax4 = axs
    # Only the first panel of each accuracy pair carries a y label.
    draw_panel(ax1, params, train_accs, "#parameters (MB)", "train accuracy (%)")
    draw_panel(ax2, flops, train_accs, "#FLOPs (M)")
    draw_panel(ax3, params, test_accs, "#parameters (MB)", "test accuracy (%)")
    draw_panel(ax4, flops, test_accs, "#FLOPs (M)")

    save_path = vis_save_dir / "tss-{:}.png".format(dataset.lower())
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
    print("{:} save into {:}".format(time_string(), save_path))
    plt.close("all")
def visualize_rank_info(api, vis_save_dir, indicator):
    """Plot test-versus-validation ranking within each dataset.

    For CIFAR-10, CIFAR-100 and ImageNet16-120 the cached accuracies are
    loaded from ``<dataset>-cache-<indicator>-info.pth`` and one panel per
    dataset is drawn. The figure is saved as
    ``<indicator>-same-relative-rank.pdf`` and ``.png`` in ``vis_save_dir``.

    The three cache files are loaded unconditionally, so they must already
    exist in ``vis_save_dir``.

    Args:
        api: Unused; kept so existing callers (which pass ``None``) keep working.
        vis_save_dir: ``pathlib.Path`` of the cache/output directory.
        indicator: Tag used in the cache and output file names (the script
            entry point passes ``"tss"`` or ``"sss"``).
    """
    vis_save_dir = vis_save_dir.resolve()
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    def load_cache(dataset):
        cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
            dataset, indicator
        )
        return torch.load(cache_path)

    cifar010_info = load_cache("cifar10")
    cifar100_info = load_cache("cifar100")
    imagenet_info = load_cache("ImageNet16-120")
    indexes = list(range(len(cifar010_info["params"])))
    print("{:} start to visualize relative ranking".format(time_string()))

    dpi, width, height = 250, 3800, 1200
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize = 14, 14
    fig, axs = plt.subplots(1, 3, figsize=figsize)
    ax1, ax2, ax3 = axs

    def get_labels(info):
        # For architectures taken in ascending test-accuracy order, return
        # their rank under validation accuracy.
        ord_test_indexes = sorted(indexes, key=lambda i: info["test_accs"][i])
        ord_valid_indexes = sorted(indexes, key=lambda i: info["valid_accs"][i])
        # A dict lookup replaces list.index(), which made this quadratic.
        valid_rank = {arch: rank for rank, arch in enumerate(ord_valid_indexes)}
        return [valid_rank[idx] for idx in ord_test_indexes]

    def plot_ax(labels, ax, name):
        # tick_params replaces per-tick ``tick.label`` access: the
        # ``Tick.label`` attribute no longer exists in recent Matplotlib.
        ax.tick_params(axis="x", labelsize=LabelSize)
        ax.tick_params(axis="y", labelsize=LabelSize, labelrotation=90)
        lowest, highest = min(indexes), max(indexes)
        ax.set_xlim(lowest, highest)
        ax.set_ylim(lowest, highest)
        # max(1, ...) keeps the tick step positive when the cache is very small.
        ax.yaxis.set_ticks(np.arange(lowest, highest, max(1, highest // 3)))
        ax.xaxis.set_ticks(np.arange(lowest, highest, max(1, highest // 5)))
        # NOTE(review): x positions follow the test-accuracy order and the
        # green points hold validation ranks, whereas the legend and x label
        # name them the other way round -- confirm the intended orientation.
        ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
        ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
        # Points placed at (-1, -1) lie outside the axis limits; they exist
        # only to give the legend large, readable markers.
        ax.scatter(
            [-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name)
        )
        ax.scatter(
            [-1],
            [-1],
            marker="o",
            s=100,
            c="tab:blue",
            label="{:} validation".format(name),
        )
        ax.legend(loc=4, fontsize=LegendFontsize)
        ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize)
        ax.set_ylabel("architecture ranking", fontsize=LabelSize)

    plot_ax(get_labels(cifar010_info), ax1, "CIFAR-10")
    plot_ax(get_labels(cifar100_info), ax2, "CIFAR-100")
    plot_ax(get_labels(imagenet_info), ax3, "ImageNet-16-120")

    save_path = (
        vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)
    ).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
    save_path = (
        vis_save_dir / "{:}-same-relative-rank.png".format(indicator)
    ).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
    print("{:} save into {:}".format(time_string(), save_path))
    plt.close("all")
def compute_kendalltau(vectori, vectorj):
    """Return Kendall's tau rank correlation between two score sequences.

    Args:
        vectori: First sequence of scores.
        vectorj: Second sequence of scores, same length as ``vectori``.

    Returns:
        The correlation statistic reported by ``scipy.stats.kendalltau``.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.stats`` has been loaded.
    from scipy import stats

    # The result unpacks as (statistic, pvalue); reading it positionally avoids
    # depending on the name of the statistic attribute.
    tau, _pvalue = stats.kendalltau(vectori, vectorj)
    return tau
def calculate_correlation(*vectors):
    """Build the pairwise Kendall-tau matrix of the given score vectors.

    Entry ``[i, j]`` is ``compute_kendalltau(vectors[i], vectors[j])``; every
    ordered pair, including the diagonal, is evaluated.

    Args:
        *vectors: Any number of equally long score sequences.

    Returns:
        A NumPy array built from the list of per-vector rows.
    """
    rows = [
        [compute_kendalltau(row_vec, col_vec) for col_vec in vectors]
        for row_vec in vectors
    ]
    return np.array(rows)
def visualize_all_rank_info(api, vis_save_dir, indicator):
    """Draw heatmaps of Kendall-tau correlations across datasets and splits.

    The cached validation and test accuracies of CIFAR-10, CIFAR-100 and
    ImageNet16-120 are loaded from ``<dataset>-cache-<indicator>-info.pth``.
    The left heatmap shows their pairwise correlation over all candidates; the
    right one restricts to candidates whose CIFAR-10 test accuracy exceeds
    92. The figure is saved as ``<indicator>-all-relative-rank.png``.

    The three cache files are loaded unconditionally, so they must already
    exist in ``vis_save_dir``.

    Args:
        api: Unused; kept so existing callers (which pass ``None``) keep working.
        vis_save_dir: ``pathlib.Path`` of the cache/output directory.
        indicator: Tag used in the cache and output file names (the script
            entry point passes ``"tss"`` or ``"sss"``).
    """
    vis_save_dir = vis_save_dir.resolve()
    vis_save_dir.mkdir(parents=True, exist_ok=True)

    def load_cache(dataset):
        cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
            dataset, indicator
        )
        return torch.load(cache_path)

    # Order matters: it must match ``tick_labels`` below.
    infos = [
        load_cache("cifar10"),
        load_cache("cifar100"),
        load_cache("ImageNet16-120"),
    ]
    cifar010_info = infos[0]
    print("{:} start to visualize relative ranking".format(time_string()))

    dpi, width, height = 250, 3200, 1400
    figsize = width / float(dpi), height / float(dpi)
    fig, axs = plt.subplots(1, 2, figsize=figsize)
    ax1, ax2 = axs
    sns_size, xformat = 15, ".2f"
    tick_labels = ["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"]

    def draw_heatmap(matrix, ax):
        sns.heatmap(
            matrix,
            annot=True,
            annot_kws={"size": sns_size},
            fmt=xformat,
            linewidths=0.5,
            ax=ax,
            xticklabels=tick_labels,
            yticklabels=tick_labels,
        )

    # Validation then test accuracies for each dataset, in tick-label order.
    all_vectors = []
    for info in infos:
        all_vectors.append(info["valid_accs"])
        all_vectors.append(info["test_accs"])
    draw_heatmap(calculate_correlation(*all_vectors), ax1)

    # Keep only the candidates above the CIFAR-10 test-accuracy bar.
    acc_bar = 92
    selected_indexes = np.flatnonzero(np.array(cifar010_info["test_accs"]) > acc_bar)
    selected_vectors = [np.array(vector)[selected_indexes] for vector in all_vectors]
    draw_heatmap(calculate_correlation(*selected_vectors), ax2)

    ax1.set_title("Correlation coefficient over ALL candidates")
    ax2.set_title(
        "Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar)
    )
    save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve()
    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
    print("{:} save into {:}".format(time_string(), save_path))
    plt.close("all")
if __name__ == "__main__":
    # Script entry point: parse the output directory, then draw every figure.
    parser = argparse.ArgumentParser(
        description="NATS-Bench",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="output/vis-nas-bench",
        help="Folder to save checkpoints and log.",
    )
    args = parser.parse_args()

    to_save_dir = Path(args.save_dir)
    datasets = ["cifar10", "cifar100", "ImageNet16-120"]

    # Figure 3 (a-c): per-dataset cost/accuracy plots from the "tss" API.
    api_tss = create(None, "tss", verbose=True)
    for xdata in datasets:
        visualize_tss_info(api_tss, xdata, to_save_dir)

    # Figure 3 (d-f): per-dataset cost/accuracy plots from the "size" API.
    api_sss = create(None, "size", verbose=True)
    for xdata in datasets:
        visualize_sss_info(api_sss, xdata, to_save_dir)

    # Figure 2, Figure 4 and Figure 5 (in that order), each drawn for "tss"
    # and then "sss" from the cache files written by the loops above.
    figure_drawers = (
        visualize_relative_info,
        visualize_rank_info,
        visualize_all_rank_info,
    )
    for draw_figure in figure_drawers:
        for space_tag in ("tss", "sss"):
            draw_figure(None, to_save_dir, space_tag)