Add histogram plotting code

This commit is contained in:
jack-willturner 2020-06-17 13:43:08 +01:00
parent de1baa10a8
commit 2a1bb3ecc1
7 changed files with 159 additions and 3 deletions

View File

@ -1,8 +1,8 @@
# Neural Architecture Search Without Training
This repository contains code for replicating our paper on NAS without training.
This repository contains code for replicating our paper on NAS without training.
## Setup
## Setup
1. Download the [datasets](https://drive.google.com/drive/folders/1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7).
2. Download [NAS-Bench-201](https://drive.google.com/file/d/1OOfVPpt-lA4u2HJrXbgrRd42IbfvJMyE/view).
@ -10,7 +10,7 @@ This repository contains code for replicating our paper on NAS without training.
We also refer the reader to instructions in the official [NASBench-201 README](https://github.com/D-X-Y/NAS-Bench-201).
## Reproducing our results
## Reproducing our results
To reproduce our results:
@ -39,6 +39,18 @@ To try different sample sizes, simply change the `--n_samples` argument in the c
Note that search times may vary from the reported result owing to hardware setup.
## Plotting histograms
In order to plot the histograms in Figure 1 of the paper, run:
```
python plot_histograms.py
```
to produce:
![alt text](results/histograms_cifar10val_batch256.png)
The code is licensed under the MIT licence.
## Acknowledgements

144
plot_histograms.py Normal file
View File

@ -0,0 +1,144 @@
import os
import argparse
import random
import numpy as np
import matplotlib.pyplot as plt
from datasets import get_datasets
from config_utils import load_config
from nas_201_api import NASBench201API as API
from models import get_cell_based_tiny_net
import torch
import torch.nn as nn
def get_batch_jacobian(net, data_loader, device):
data_iterator = iter(data_loader)
x, target = next(data_iterator)
x = x.to(device)
net.zero_grad()
x.requires_grad_(True)
_, y = net(x)
y.backward(torch.ones_like(y))
jacob = x.grad.detach()
return jacob, target.detach()
def plot_hist(jacob, ax, colour):
xx = jacob.reshape(jacob.size(0), -1).cpu().numpy()
corrs = np.corrcoef(xx)
ax.hist(corrs.flatten(), bins=100, color=colour)
def decide_plot(acc, plt_cts, num_rows, boundaries=[60., 70., 80., 90.]):
if acc < boundaries[0]:
plt_col = 0
accrange = f'< {boundaries[0]}%'
elif acc < boundaries[1]:
plt_col = 1
accrange = f'[{boundaries[0]}% , {boundaries[1]}%)'
elif acc < boundaries[2]:
plt_col = 2
accrange = f'[{boundaries[1]}% , {boundaries[2]}%)'
elif acc < boundaries[3]:
accrange = f'[{boundaries[2]}% , {boundaries[3]}%)'
plt_col = 3
else:
accrange = f'>= {boundaries[3]}%'
plt_col = 4
can_plot = False
plt_row = 0
if plt_cts[plt_col] < num_rows:
can_plot = True
plt_row = plt_cts[plt_col]
plt_cts[plt_col] += 1
return can_plot, plt_row, plt_col, accrange
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Plot histograms of correlation matrix')
parser.add_argument('--data_loc', default='../datasets/cifar/', type=str, help='dataset folder')
parser.add_argument('--api_loc', default='../datasets/NAS-Bench-201-v1_1-096897.pth',
type=str, help='path to API')
parser.add_argument('--arch_start', default=0, type=int)
parser.add_argument('--arch_end', default=15625, type=int)
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--GPU', default='0', type=str)
parser.add_argument('--batch_size', default=256, type=int)
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU
# Reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ARCH_START = args.arch_start
ARCH_END = args.arch_end
criterion = nn.CrossEntropyLoss()
train_data, valid_data, xshape, class_num = get_datasets('cifar10', args.data_loc, 0)
cifar_split = load_config('config_utils/cifar-split.txt', None, None)
train_split, valid_split = cifar_split.train, cifar_split.valid
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
num_workers=0, pin_memory=True, sampler= torch.utils.data.sampler.SubsetRandomSampler(train_split))
scores = []
accs = []
plot_shape = (25, 5)
num_plots = plot_shape[0]*plot_shape[1]
fig, axes = plt.subplots(*plot_shape, sharex=True, figsize=(9, 9) )
plt_cts = [0 for i in range(plot_shape[1])]
api = API(args.api_loc)
archs = list(range(ARCH_START, ARCH_END))
colours = ['#811F41', '#A92941', '#D15141', '#EF7941', '#F99C4B']
strs = []
random.shuffle(archs)
for arch in archs:
try:
config = api.get_net_config(arch, 'cifar10')
archinfo = api.query_meta_info_by_index(arch)
acc = archinfo.get_metrics('cifar10-valid', 'x-valid')['accuracy']
network = get_cell_based_tiny_net(config)
network = network.to(device)
jacobs, labels = get_batch_jacobian(network, train_loader, device)
boundaries = [60., 70., 80., 90.]
can_plt, row, col, accrange = decide_plot(acc, plt_cts, plot_shape[0], boundaries)
if not can_plt:
continue
axes[row, col].axis('off')
plot_hist(jacobs, axes[row, col], colours[col])
if row == 0:
axes[row, col].set_title(f'{accrange}')
if row + 1 == plot_shape[0]:
axes[row, col].axis('on')
plt.setp(axes[row, col].get_xticklabels(), fontsize=12)
axes[row, col].spines["top"].set_visible(False)
axes[row, col].spines["right"].set_visible(False)
axes[row, col].spines["left"].set_visible(False)
axes[row, col].set_yticks([])
if sum(plt_cts) == num_plots:
plt.tight_layout()
plt.savefig(f'results/histograms_cifar10val_batch{args.batch_size}.png')
plt.show()
break
except Exception as e:
plt_cts[col] -= 1
continue

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB