diff --git a/README.md b/README.md index 1826655..4ecf143 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # Neural Architecture Search Without Training -This repository contains code for replicating our paper on NAS without training. +This repository contains code for replicating our paper on NAS without training. -## Setup +## Setup 1. Download the [datasets](https://drive.google.com/drive/folders/1L0Lzq8rWpZLPfiQGd6QR8q5xLV88emU7). 2. Download [NAS-Bench-201](https://drive.google.com/file/d/1OOfVPpt-lA4u2HJrXbgrRd42IbfvJMyE/view). @@ -10,7 +10,7 @@ This repository contains code for replicating our paper on NAS without training. We also refer the reader to instructions in the official [NASBench-201 README](https://github.com/D-X-Y/NAS-Bench-201). -## Reproducing our results +## Reproducing our results To reproduce our results: @@ -39,6 +39,18 @@ To try different sample sizes, simply change the `--n_samples` argument in the c Note that search times may vary from the reported result owing to hardware setup. + +## Plotting histograms + +In order to plot the histograms in Figure 1 of the paper, run: + +``` +python plot_histograms.py +``` +to produce: + +![alt text](results/histograms_cifar10val_batch256.png) + The code is licensed under the MIT licence. ## Acknowledgements diff --git a/plot_histograms.py b/plot_histograms.py new file mode 100644 index 0000000..1295024 --- /dev/null +++ b/plot_histograms.py @@ -0,0 +1,144 @@ +import os +import argparse +import random +import numpy as np + +import matplotlib.pyplot as plt +from datasets import get_datasets +from config_utils import load_config + +from nas_201_api import NASBench201API as API +from models import get_cell_based_tiny_net +import torch +import torch.nn as nn + + +def get_batch_jacobian(net, data_loader, device): + data_iterator = iter(data_loader) + x, target = next(data_iterator) + x = x.to(device) + net.zero_grad() + x.requires_grad_(True) + _, y = net(x) + y.backward(torch.ones_like(y)) + jacob = x.grad.detach() + return jacob, target.detach() + +def plot_hist(jacob, ax, colour): + xx = jacob.reshape(jacob.size(0), -1).cpu().numpy() + corrs = np.corrcoef(xx) + ax.hist(corrs.flatten(), bins=100, color=colour) + +def decide_plot(acc, plt_cts, num_rows, boundaries=[60., 70., 80., 90.]): + if acc < boundaries[0]: + plt_col = 0 + accrange = f'< {boundaries[0]}%' + elif acc < boundaries[1]: + plt_col = 1 + accrange = f'[{boundaries[0]}% , {boundaries[1]}%)' + elif acc < boundaries[2]: + plt_col = 2 + accrange = f'[{boundaries[1]}% , {boundaries[2]}%)' + elif acc < boundaries[3]: + accrange = f'[{boundaries[2]}% , {boundaries[3]}%)' + plt_col = 3 + else: + accrange = f'>= {boundaries[3]}%' + plt_col = 4 + + can_plot = False + plt_row = 0 + if plt_cts[plt_col] < num_rows: + can_plot = True + plt_row = plt_cts[plt_col] + plt_cts[plt_col] += 1 + + return can_plot, plt_row, plt_col, accrange + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Plot histograms of correlation matrix') + parser.add_argument('--data_loc', default='../datasets/cifar/', type=str, help='dataset folder') + parser.add_argument('--api_loc', default='../datasets/NAS-Bench-201-v1_1-096897.pth', + type=str, help='path to API') + parser.add_argument('--arch_start', default=0, type=int) + parser.add_argument('--arch_end', default=15625, type=int) + parser.add_argument('--seed', default=42, type=int) + parser.add_argument('--GPU', default='0', type=str) + parser.add_argument('--batch_size', default=256, type=int) + + args = parser.parse_args() + os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU + + # Reproducibility + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + ARCH_START = args.arch_start + ARCH_END = args.arch_end + + criterion = nn.CrossEntropyLoss() + train_data, valid_data, xshape, class_num = get_datasets('cifar10', args.data_loc, 0) + + cifar_split = load_config('config_utils/cifar-split.txt', None, None) + train_split, valid_split = cifar_split.train, cifar_split.valid + train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, + num_workers=0, pin_memory=True, sampler= torch.utils.data.sampler.SubsetRandomSampler(train_split)) + + scores = [] + accs = [] + + plot_shape = (25, 5) + num_plots = plot_shape[0]*plot_shape[1] + fig, axes = plt.subplots(*plot_shape, sharex=True, figsize=(9, 9) ) + plt_cts = [0 for i in range(plot_shape[1])] + + api = API(args.api_loc) + + archs = list(range(ARCH_START, ARCH_END)) + colours = ['#811F41', '#A92941', '#D15141', '#EF7941', '#F99C4B'] + + strs = [] + random.shuffle(archs) + for arch in archs: + try: + config = api.get_net_config(arch, 'cifar10') + archinfo = api.query_meta_info_by_index(arch) + acc = archinfo.get_metrics('cifar10-valid', 'x-valid')['accuracy'] + + network = get_cell_based_tiny_net(config) + network = network.to(device) + jacobs, labels = get_batch_jacobian(network, train_loader, device) + + boundaries = [60., 70., 80., 90.] + can_plt, row, col, accrange = decide_plot(acc, plt_cts, plot_shape[0], boundaries) + if not can_plt: + continue + axes[row, col].axis('off') + + plot_hist(jacobs, axes[row, col], colours[col]) + if row == 0: + axes[row, col].set_title(f'{accrange}') + + if row + 1 == plot_shape[0]: + axes[row, col].axis('on') + plt.setp(axes[row, col].get_xticklabels(), fontsize=12) + axes[row, col].spines["top"].set_visible(False) + axes[row, col].spines["right"].set_visible(False) + axes[row, col].spines["left"].set_visible(False) + axes[row, col].set_yticks([]) + + if sum(plt_cts) == num_plots: + plt.tight_layout() + plt.savefig(f'results/histograms_cifar10val_batch{args.batch_size}.png') + plt.show() + break + except Exception as e: + plt_cts[col] -= 1 + continue diff --git a/results/ImageNet16-120_500_10_1.t7 b/results/ImageNet16-120_500_10_1.t7 deleted file mode 100644 index 2d2e4b7..0000000 Binary files a/results/ImageNet16-120_500_10_1.t7 and /dev/null differ diff --git a/results/cifar10-valid_500_10_1.t7 b/results/cifar10-valid_500_10_1.t7 deleted file mode 100644 index eaecec6..0000000 Binary files a/results/cifar10-valid_500_10_1.t7 and /dev/null differ diff --git a/results/cifar100_500_10_1.t7 b/results/cifar100_500_10_1.t7 deleted file mode 100644 index 20ac5f6..0000000 Binary files a/results/cifar100_500_10_1.t7 and /dev/null differ diff --git a/results/cifar10_500_10_1.t7 b/results/cifar10_500_10_1.t7 deleted file mode 100644 index abfc18d..0000000 Binary files a/results/cifar10_500_10_1.t7 and /dev/null differ diff --git a/results/histograms_cifar10val_batch256.png b/results/histograms_cifar10val_batch256.png new file mode 100644 index 0000000..6dedc09 Binary files /dev/null and b/results/histograms_cifar10val_batch256.png differ