Add histogram plotting code
This commit is contained in:
parent
de1baa10a8
commit
2a1bb3ecc1
12
README.md
12
README.md
@ -39,6 +39,18 @@ To try different sample sizes, simply change the `--n_samples` argument in the c
|
||||
|
||||
Note that search times may vary from the reported result owing to hardware setup.
|
||||
|
||||
|
||||
## Plotting histograms
|
||||
|
||||
In order to plot the histograms in Figure 1 of the paper, run:
|
||||
|
||||
```
|
||||
python plot_histograms.py
|
||||
```
|
||||
to produce:
|
||||
|
||||

|
||||
|
||||
The code is licensed under the MIT licence.
|
||||
|
||||
## Acknowledgements
|
||||
|
144
plot_histograms.py
Normal file
144
plot_histograms.py
Normal file
@ -0,0 +1,144 @@
|
||||
import os
|
||||
import argparse
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from datasets import get_datasets
|
||||
from config_utils import load_config
|
||||
|
||||
from nas_201_api import NASBench201API as API
|
||||
from models import get_cell_based_tiny_net
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def get_batch_jacobian(net, data_loader, device):
|
||||
data_iterator = iter(data_loader)
|
||||
x, target = next(data_iterator)
|
||||
x = x.to(device)
|
||||
net.zero_grad()
|
||||
x.requires_grad_(True)
|
||||
_, y = net(x)
|
||||
y.backward(torch.ones_like(y))
|
||||
jacob = x.grad.detach()
|
||||
return jacob, target.detach()
|
||||
|
||||
def plot_hist(jacob, ax, colour):
|
||||
xx = jacob.reshape(jacob.size(0), -1).cpu().numpy()
|
||||
corrs = np.corrcoef(xx)
|
||||
ax.hist(corrs.flatten(), bins=100, color=colour)
|
||||
|
||||
def decide_plot(acc, plt_cts, num_rows, boundaries=[60., 70., 80., 90.]):
|
||||
if acc < boundaries[0]:
|
||||
plt_col = 0
|
||||
accrange = f'< {boundaries[0]}%'
|
||||
elif acc < boundaries[1]:
|
||||
plt_col = 1
|
||||
accrange = f'[{boundaries[0]}% , {boundaries[1]}%)'
|
||||
elif acc < boundaries[2]:
|
||||
plt_col = 2
|
||||
accrange = f'[{boundaries[1]}% , {boundaries[2]}%)'
|
||||
elif acc < boundaries[3]:
|
||||
accrange = f'[{boundaries[2]}% , {boundaries[3]}%)'
|
||||
plt_col = 3
|
||||
else:
|
||||
accrange = f'>= {boundaries[3]}%'
|
||||
plt_col = 4
|
||||
|
||||
can_plot = False
|
||||
plt_row = 0
|
||||
if plt_cts[plt_col] < num_rows:
|
||||
can_plot = True
|
||||
plt_row = plt_cts[plt_col]
|
||||
plt_cts[plt_col] += 1
|
||||
|
||||
return can_plot, plt_row, plt_col, accrange
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Plot histograms of correlation matrix')
|
||||
parser.add_argument('--data_loc', default='../datasets/cifar/', type=str, help='dataset folder')
|
||||
parser.add_argument('--api_loc', default='../datasets/NAS-Bench-201-v1_1-096897.pth',
|
||||
type=str, help='path to API')
|
||||
parser.add_argument('--arch_start', default=0, type=int)
|
||||
parser.add_argument('--arch_end', default=15625, type=int)
|
||||
parser.add_argument('--seed', default=42, type=int)
|
||||
parser.add_argument('--GPU', default='0', type=str)
|
||||
parser.add_argument('--batch_size', default=256, type=int)
|
||||
|
||||
args = parser.parse_args()
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU
|
||||
|
||||
# Reproducibility
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
ARCH_START = args.arch_start
|
||||
ARCH_END = args.arch_end
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
train_data, valid_data, xshape, class_num = get_datasets('cifar10', args.data_loc, 0)
|
||||
|
||||
cifar_split = load_config('config_utils/cifar-split.txt', None, None)
|
||||
train_split, valid_split = cifar_split.train, cifar_split.valid
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
|
||||
num_workers=0, pin_memory=True, sampler= torch.utils.data.sampler.SubsetRandomSampler(train_split))
|
||||
|
||||
scores = []
|
||||
accs = []
|
||||
|
||||
plot_shape = (25, 5)
|
||||
num_plots = plot_shape[0]*plot_shape[1]
|
||||
fig, axes = plt.subplots(*plot_shape, sharex=True, figsize=(9, 9) )
|
||||
plt_cts = [0 for i in range(plot_shape[1])]
|
||||
|
||||
api = API(args.api_loc)
|
||||
|
||||
archs = list(range(ARCH_START, ARCH_END))
|
||||
colours = ['#811F41', '#A92941', '#D15141', '#EF7941', '#F99C4B']
|
||||
|
||||
strs = []
|
||||
random.shuffle(archs)
|
||||
for arch in archs:
|
||||
try:
|
||||
config = api.get_net_config(arch, 'cifar10')
|
||||
archinfo = api.query_meta_info_by_index(arch)
|
||||
acc = archinfo.get_metrics('cifar10-valid', 'x-valid')['accuracy']
|
||||
|
||||
network = get_cell_based_tiny_net(config)
|
||||
network = network.to(device)
|
||||
jacobs, labels = get_batch_jacobian(network, train_loader, device)
|
||||
|
||||
boundaries = [60., 70., 80., 90.]
|
||||
can_plt, row, col, accrange = decide_plot(acc, plt_cts, plot_shape[0], boundaries)
|
||||
if not can_plt:
|
||||
continue
|
||||
axes[row, col].axis('off')
|
||||
|
||||
plot_hist(jacobs, axes[row, col], colours[col])
|
||||
if row == 0:
|
||||
axes[row, col].set_title(f'{accrange}')
|
||||
|
||||
if row + 1 == plot_shape[0]:
|
||||
axes[row, col].axis('on')
|
||||
plt.setp(axes[row, col].get_xticklabels(), fontsize=12)
|
||||
axes[row, col].spines["top"].set_visible(False)
|
||||
axes[row, col].spines["right"].set_visible(False)
|
||||
axes[row, col].spines["left"].set_visible(False)
|
||||
axes[row, col].set_yticks([])
|
||||
|
||||
if sum(plt_cts) == num_plots:
|
||||
plt.tight_layout()
|
||||
plt.savefig(f'results/histograms_cifar10val_batch{args.batch_size}.png')
|
||||
plt.show()
|
||||
break
|
||||
except Exception as e:
|
||||
plt_cts[col] -= 1
|
||||
continue
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
results/histograms_cifar10val_batch256.png
Normal file
BIN
results/histograms_cifar10val_batch256.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 30 KiB |
Loading…
Reference in New Issue
Block a user