"""Score every NAS-Bench-201 architecture with the SWAP training-free metric.

For each architecture in NAS-Bench-201 (CIFAR-10), the script:

1. builds the network;
2. computes the SWAP score ``--repeats`` times, with a fresh Gaussian
   weight initialisation each time, on one fixed input batch;
3. records the mean score next to the benchmark accuracy.

At the end it prints Spearman's rank correlation between score and accuracy
and writes all rows to ``swap_results.csv``.
"""
import os

# Must be set before torch is imported so unsupported MPS ops fall back to CPU.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import argparse
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from scipy import stats

# NOTE: star imports kept (and kept *after* the stdlib/third-party imports, as in
# the original) because names such as network_weight_gaussian_init come from here
# and the original import order determines which binding wins.
from src.utils.utilities import *
from src.metrics.swap import SWAP
from src.datasets.utilities import get_datasets
from src.search_space.networks import *
import time

# NASBench-201
from nas_201_api import NASBench201API as API

# xautodl
from xautodl.models import get_cell_based_tiny_net

# Initialize NAS-Bench-201 (loaded at import time, as before).
nas_201_path = 'datasets/NAS-Bench-201-v1_1-096897.pth'
print(f'Loading NAS-Bench-201 from {nas_201_path}')
start_time = time.time()
api = API(nas_201_path)
end_time = time.time()
print(f'Loaded NAS-Bench-201 in {end_time - start_time:.2f} seconds')

# Settings for console outputs
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

parser = argparse.ArgumentParser()
# general setting
parser.add_argument('--data_path', default="datasets", type=str, nargs='?',
                    help='path to the image dataset (datasets or datasets/ILSVRC/Data/CLS-LOC)')
parser.add_argument('--seed', default=0, type=int, help='random seed')
parser.add_argument('--device', default="cuda", type=str, nargs='?',
                    help='setup device (cpu, mps or cuda)')
parser.add_argument('--repeats', default=32, type=int, nargs='?',
                    help='times of calculating the training-free metric')
parser.add_argument('--input_samples', default=16, type=int, nargs='?',
                    help='input batch size for training-free metric')
args = parser.parse_args()

if __name__ == "__main__":
    device = torch.device(args.device)

    train_data, _, _ = get_datasets('cifar10', args.data_path, (args.input_samples, 3, 32, 32), -1)
    # Pinned host memory only helps (and only avoids a warning) for CUDA transfers.
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.input_samples,
                                               num_workers=0,
                                               pin_memory=(device.type == 'cuda'))
    loader = iter(train_loader)
    # One fixed mini-batch is reused for every architecture so scores are comparable.
    inputs, _ = next(loader)

    results = []
    nasbench_len = 15625  # number of architectures in NAS-Bench-201
    for arch_index in range(nasbench_len):
        print(f'Evaluating network: {arch_index}')
        config = api.get_net_config(arch_index, 'cifar10')
        network = get_cell_based_tiny_net(config)
        nas_results = api.query_by_index(arch_index, 'cifar10')
        # Results are keyed by training seed; 111 is the seed used here.
        acc = nas_results[111].get_eval('ori-test')
        # get_eval in nas_201_api returns a dict (name/loss/accuracy/...);
        # keep only the numeric accuracy so the correlation below is computable.
        if isinstance(acc, dict):
            acc = acc['accuracy']
        print(type(network))

        start_time = time.time()
        network = network.to(device)
        end_time = time.time()
        print(f'Loaded network in {end_time - start_time:.2f} seconds')

        print(f'initiliazing SWAP')
        swap = SWAP(model=network, inputs=inputs, device=device, seed=args.seed)
        swap_score = []

        print(f'Calculating SWAP score')
        start_time = time.time()
        # Distinct loop variable: reusing the architecture index here used to
        # overwrite it, so every result row recorded args.repeats - 1 as its index.
        for repeat_idx in range(args.repeats):
            print(f'Iteration: {repeat_idx+1}/{args.repeats}', end='\r')
            network = network.apply(network_weight_gaussian_init)
            swap.reinit()
            swap_score.append(swap.forward())
            swap.clear()
        end_time = time.time()

        mean_score = np.mean(swap_score)
        print(f'Average SWAP score: {mean_score}')
        print(f'Elapsed time: {end_time - start_time:.2f} seconds')
        results.append([mean_score, acc, arch_index])

    # NOTE: the 'valid_acc' column actually holds the 'ori-test' metric; the name
    # is kept unchanged for compatibility with existing consumers of the CSV.
    results = pd.DataFrame(results, columns=['swap_score', 'valid_acc', 'index'])
    print()
    print(f'Spearman\'s Correlation Coefficient: {stats.spearmanr(results.swap_score, results.valid_acc)[0]}')
    results.to_csv('swap_results.csv', float_format='%.4f', index=False)